fix conflict
|
@ -1,8 +1,9 @@
|
|||
include LICENSE.txt
|
||||
include LICENSE
|
||||
include README.md
|
||||
|
||||
recursive-include ppocr/utils *.txt utility.py logging.py
|
||||
recursive-include ppocr/data/ *.py
|
||||
recursive-include ppocr/utils *.txt utility.py logging.py network.py
|
||||
recursive-include ppocr/data *.py
|
||||
recursive-include ppocr/postprocess *.py
|
||||
recursive-include tools/infer *.py
|
||||
recursive-include ppocr/utils/e2e_utils/ *.py
|
||||
recursive-include ppocr/utils/e2e_utils *.py
|
||||
recursive-include ppstructure *.py
|
|
@ -92,7 +92,7 @@ class WindowMixin(object):
|
|||
class MainWindow(QMainWindow, WindowMixin):
|
||||
FIT_WINDOW, FIT_WIDTH, MANUAL_ZOOM = list(range(3))
|
||||
|
||||
def __init__(self, lang="ch", defaultFilename=None, defaultPrefdefClassFile=None, defaultSaveDir=None):
|
||||
def __init__(self, lang="ch", gpu=False, defaultFilename=None, defaultPrefdefClassFile=None, defaultSaveDir=None):
|
||||
super(MainWindow, self).__init__()
|
||||
self.setWindowTitle(__appname__)
|
||||
|
||||
|
@ -108,7 +108,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
getStr = lambda strId: self.stringBundle.getString(strId)
|
||||
|
||||
self.defaultSaveDir = defaultSaveDir
|
||||
self.ocr = PaddleOCR(use_pdserving=False, use_angle_cls=True, det=True, cls=True, use_gpu=False, lang=lang)
|
||||
self.ocr = PaddleOCR(use_pdserving=False, use_angle_cls=True, det=True, cls=True, use_gpu=gpu, lang=lang)
|
||||
|
||||
if os.path.exists('./data/paddle.png'):
|
||||
result = self.ocr.ocr('./data/paddle.png', cls=True, det=True)
|
||||
|
@ -398,6 +398,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
help = action(getStr('tutorial'), self.showTutorialDialog, None, 'help', getStr('tutorialDetail'))
|
||||
showInfo = action(getStr('info'), self.showInfoDialog, None, 'help', getStr('info'))
|
||||
showSteps = action(getStr('steps'), self.showStepsDialog, None, 'help', getStr('steps'))
|
||||
showKeys = action(getStr('keys'), self.showKeysDialog, None, 'help', getStr('keys'))
|
||||
|
||||
zoom = QWidgetAction(self)
|
||||
zoom.setDefaultWidget(self.zoomWidget)
|
||||
|
@ -565,7 +566,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
addActions(self.menus.file,
|
||||
(opendir, None, saveLabel, saveRec, self.autoSaveOption, None, resetAll, deleteImg, quit))
|
||||
|
||||
addActions(self.menus.help, (showSteps, showInfo))
|
||||
addActions(self.menus.help, (showKeys,showSteps, showInfo))
|
||||
addActions(self.menus.view, (
|
||||
self.displayLabelOption, self.labelDialogOption,
|
||||
None,
|
||||
|
@ -760,6 +761,10 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
msg = stepsInfo(self.lang)
|
||||
QMessageBox.information(self, u'Information', msg)
|
||||
|
||||
def showKeysDialog(self):
|
||||
msg = keysInfo(self.lang)
|
||||
QMessageBox.information(self, u'Information', msg)
|
||||
|
||||
def createShape(self):
|
||||
assert self.beginner()
|
||||
self.canvas.setEditing(False)
|
||||
|
@ -1239,6 +1244,8 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
|
||||
def loadFile(self, filePath=None):
|
||||
"""Load the specified file, or the last opened file if None."""
|
||||
if self.dirty:
|
||||
self.mayContinue()
|
||||
self.resetState()
|
||||
self.canvas.setEnabled(False)
|
||||
if filePath is None:
|
||||
|
@ -2037,6 +2044,8 @@ def read(filename, default=None):
|
|||
except:
|
||||
return default
|
||||
|
||||
def str2bool(v):
|
||||
return v.lower() in ("true", "t", "1")
|
||||
|
||||
def get_main_app(argv=[]):
|
||||
"""
|
||||
|
@ -2048,13 +2057,14 @@ def get_main_app(argv=[]):
|
|||
app.setWindowIcon(newIcon("app"))
|
||||
# Tzutalin 201705+: Accept extra agruments to change predefined class file
|
||||
argparser = argparse.ArgumentParser()
|
||||
argparser.add_argument("--lang", default='en', nargs="?")
|
||||
argparser.add_argument("--lang", type=str, default='en', nargs="?")
|
||||
argparser.add_argument("--gpu", type=str2bool, default=False, nargs="?")
|
||||
argparser.add_argument("--predefined_classes_file",
|
||||
default=os.path.join(os.path.dirname(__file__), "data", "predefined_classes.txt"),
|
||||
nargs="?")
|
||||
args = argparser.parse_args(argv[1:])
|
||||
# Usage : labelImg.py image predefClassFile saveDir
|
||||
win = MainWindow(lang=args.lang,
|
||||
win = MainWindow(lang=args.lang, gpu=args.gpu,
|
||||
defaultPrefdefClassFile=args.predefined_classes_file)
|
||||
win.show()
|
||||
return app, win
|
||||
|
|
|
@ -174,6 +174,7 @@ def stepsInfo(lang='en'):
|
|||
"10. 标注结果:关闭应用程序或切换文件路径后,手动保存过的标签将会被存放在所打开图片文件夹下的" \
|
||||
"*Label.txt*中。在菜单栏点击 “PaddleOCR” - 保存识别结果后,会将此类图片的识别训练数据保存在*crop_img*文件夹下," \
|
||||
"识别标签保存在*rec_gt.txt*中。\n"
|
||||
|
||||
else:
|
||||
msg = "1. Build and launch using the instructions above.\n" \
|
||||
"2. Click 'Open Dir' in Menu/File to select the folder of the picture.\n"\
|
||||
|
@ -187,5 +188,57 @@ def stepsInfo(lang='en'):
|
|||
"8. Click 'Save', the image status will switch to '√',then the program automatically jump to the next.\n"\
|
||||
"9. Click 'Delete Image' and the image will be deleted to the recycle bin.\n"\
|
||||
"10. Labeling result: After closing the application or switching the file path, the manually saved label will be stored in *Label.txt* under the opened picture folder.\n"\
|
||||
" Click PaddleOCR-Save Recognition Results in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*.\n"
|
||||
" Click PaddleOCR-Save Recognition Results in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*.\n"
|
||||
|
||||
return msg
|
||||
|
||||
def keysInfo(lang='en'):
|
||||
if lang == 'ch':
|
||||
msg = "快捷键\t\t\t说明\n" \
|
||||
"———————————————————————\n"\
|
||||
"Ctrl + shift + R\t\t对当前图片的所有标记重新识别\n" \
|
||||
"W\t\t\t新建矩形框\n" \
|
||||
"Q\t\t\t新建四点框\n" \
|
||||
"Ctrl + E\t\t编辑所选框标签\n" \
|
||||
"Ctrl + R\t\t重新识别所选标记\n" \
|
||||
"Ctrl + C\t\t复制并粘贴选中的标记框\n" \
|
||||
"Ctrl + 鼠标左键\t\t多选标记框\n" \
|
||||
"Backspace\t\t删除所选框\n" \
|
||||
"Ctrl + V\t\t确认本张图片标记\n" \
|
||||
"Ctrl + Shift + d\t删除本张图片\n" \
|
||||
"D\t\t\t下一张图片\n" \
|
||||
"A\t\t\t上一张图片\n" \
|
||||
"Ctrl++\t\t\t缩小\n" \
|
||||
"Ctrl--\t\t\t放大\n" \
|
||||
"↑→↓←\t\t\t移动标记框\n" \
|
||||
"———————————————————————\n" \
|
||||
"注:Mac用户Command键替换上述Ctrl键"
|
||||
|
||||
else:
|
||||
msg = "Shortcut Keys\t\tDescription\n" \
|
||||
"———————————————————————\n" \
|
||||
"Ctrl + shift + R\t\tRe-recognize all the labels\n" \
|
||||
"\t\t\tof the current image\n" \
|
||||
"\n"\
|
||||
"W\t\t\tCreate a rect box\n" \
|
||||
"Q\t\t\tCreate a four-points box\n" \
|
||||
"Ctrl + E\t\tEdit label of the selected box\n" \
|
||||
"Ctrl + R\t\tRe-recognize the selected box\n" \
|
||||
"Ctrl + C\t\tCopy and paste the selected\n" \
|
||||
"\t\t\tbox\n" \
|
||||
"\n"\
|
||||
"Ctrl + Left Mouse\tMulti select the label\n" \
|
||||
"Button\t\t\tbox\n" \
|
||||
"\n"\
|
||||
"Backspace\t\tDelete the selected box\n" \
|
||||
"Ctrl + V\t\tCheck image\n" \
|
||||
"Ctrl + Shift + d\tDelete image\n" \
|
||||
"D\t\t\tNext image\n" \
|
||||
"A\t\t\tPrevious image\n" \
|
||||
"Ctrl++\t\t\tZoom in\n" \
|
||||
"Ctrl--\t\t\tZoom out\n" \
|
||||
"↑→↓←\t\t\tMove selected box" \
|
||||
"———————————————————————\n" \
|
||||
"Notice:For Mac users, use the 'Command' key instead of the 'Ctrl' key"
|
||||
|
||||
return msg
|
|
@ -89,6 +89,7 @@ saveRec=保存识别结果
|
|||
tempLabel=待识别
|
||||
nullLabel=无法识别
|
||||
steps=操作步骤
|
||||
keys=快捷键
|
||||
choseModelLg=选择模型语言
|
||||
cancel=取消
|
||||
ok=确认
|
||||
|
|
|
@ -89,6 +89,7 @@ saveRec=Save Recognition Result
|
|||
tempLabel=TEMPORARY
|
||||
nullLabel=NULL
|
||||
steps=Steps
|
||||
keys=Shortcut Keys
|
||||
choseModelLg=Choose Model Language
|
||||
cancel=Cancel
|
||||
ok=OK
|
||||
|
|
|
@ -66,6 +66,7 @@ class StdTextDrawer(object):
|
|||
corpus_list.append(corpus[0:i])
|
||||
text_input_list.append(text_input)
|
||||
corpus = corpus[i:]
|
||||
i = 0
|
||||
break
|
||||
draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
|
||||
char_x += char_size
|
||||
|
@ -78,7 +79,6 @@ class StdTextDrawer(object):
|
|||
|
||||
corpus_list.append(corpus[0:i])
|
||||
text_input_list.append(text_input)
|
||||
corpus = corpus[i:]
|
||||
break
|
||||
|
||||
return corpus_list, text_input_list
|
||||
|
|
|
@ -11,7 +11,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import paddleocr
|
||||
from .paddleocr import *
|
||||
|
||||
__all__ = ['PaddleOCR', 'draw_ocr']
|
||||
from .paddleocr import PaddleOCR
|
||||
from .tools.infer.utility import draw_ocr
|
||||
__version__ = paddleocr.VERSION
|
||||
__all__ = ['PaddleOCR', 'PPStructure', 'draw_ocr', 'draw_structure_result', 'save_structure_res','download_with_progressbar']
|
||||
|
|
|
@ -0,0 +1,202 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/ch_db_mv3/
|
||||
save_epoch_step: 1200
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [3000, 2000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
|
||||
Architecture:
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Student:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Student2:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Teacher:
|
||||
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
|
||||
freeze_params: true
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 18
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 256
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- DistillationDilaDBLoss:
|
||||
weight: 1.0
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
- ["Student2", "Teacher"]
|
||||
key: maps
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
- DistillationDMLLoss:
|
||||
model_name_pairs:
|
||||
- ["Student", "Student2"]
|
||||
maps_name: "thrink_maps"
|
||||
weight: 1.0
|
||||
# act: None
|
||||
model_name_pairs: ["Student", "Student2"]
|
||||
key: maps
|
||||
- DistillationDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Student2"]
|
||||
# key: maps
|
||||
# name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DistillationDBPostProcess
|
||||
model_name: ["Student", "Student2", "Teacher"]
|
||||
# key: maps
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DistillationMetric
|
||||
base_metric_name: DetMetric
|
||||
main_indicator: hmean
|
||||
key: "Student"
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||
- EastRandomCropData:
|
||||
size: [960, 960]
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
# image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -0,0 +1,174 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/ch_db_mv3/
|
||||
save_epoch_step: 1200
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [3000, 2000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
|
||||
Architecture:
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Student:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Teacher:
|
||||
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
|
||||
freeze_params: true
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 18
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 256
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- DistillationDilaDBLoss:
|
||||
weight: 1.0
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
key: maps
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
- DistillationDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Teacher"]
|
||||
# key: maps
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DistillationDBPostProcess
|
||||
model_name: ["Student", "Student2"]
|
||||
key: head_out
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DistillationMetric
|
||||
base_metric_name: DetMetric
|
||||
main_indicator: hmean
|
||||
key: "Student"
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||
- EastRandomCropData:
|
||||
size: [960, 960]
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
# image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -0,0 +1,176 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/ch_db_mv3/
|
||||
save_epoch_step: 1200
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [3000, 2000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
|
||||
Architecture:
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Student:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Student2:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- DistillationDMLLoss:
|
||||
model_name_pairs:
|
||||
- ["Student", "Student2"]
|
||||
maps_name: "thrink_maps"
|
||||
weight: 1.0
|
||||
act: "softmax"
|
||||
model_name_pairs: ["Student", "Student2"]
|
||||
key: maps
|
||||
- DistillationDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Student2"]
|
||||
# key: maps
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DistillationDBPostProcess
|
||||
model_name: ["Student", "Student2"]
|
||||
key: head_out
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DistillationMetric
|
||||
base_metric_name: DetMetric
|
||||
main_indicator: hmean
|
||||
key: "Student"
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||
- EastRandomCropData:
|
||||
size: [960, 960]
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
# image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -0,0 +1,159 @@
|
|||
Global:
|
||||
debug: false
|
||||
use_gpu: true
|
||||
epoch_num: 800
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec_chinese_lite_distillation_v2.1
|
||||
save_epoch_step: 3
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: true
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: false
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
|
||||
character_type: ch
|
||||
max_text_length: 25
|
||||
infer_mode: false
|
||||
use_space_char: true
|
||||
distributed: true
|
||||
save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Piecewise
|
||||
decay_epochs : [700, 800]
|
||||
values : [0.001, 0.0001]
|
||||
warmup_epoch: 5
|
||||
regularizer:
|
||||
name: L2
|
||||
factor: 2.0e-05
|
||||
|
||||
Architecture:
|
||||
model_type: &model_type "rec"
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Teacher:
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
Student:
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- DistillationCTCLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Teacher"]
|
||||
key: head_out
|
||||
- DistillationDMLLoss:
|
||||
weight: 1.0
|
||||
act: "softmax"
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
key: head_out
|
||||
- DistillationDistanceLoss:
|
||||
weight: 1.0
|
||||
mode: "l2"
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
key: backbone_out
|
||||
|
||||
PostProcess:
|
||||
name: DistillationCTCLabelDecode
|
||||
model_name: ["Student", "Teacher"]
|
||||
key: head_out
|
||||
|
||||
Metric:
|
||||
name: DistillationMetric
|
||||
base_metric_name: RecMetric
|
||||
main_indicator: acc
|
||||
key: "Student"
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/
|
||||
label_file_list:
|
||||
- ./train_data/train_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- RecAug:
|
||||
- CTCLabelEncode:
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 320]
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label
|
||||
- length
|
||||
loader:
|
||||
shuffle: true
|
||||
batch_size_per_card: 128
|
||||
drop_last: true
|
||||
num_sections: 1
|
||||
num_workers: 8
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data
|
||||
label_file_list:
|
||||
- ./train_data/val_list.txt
|
||||
transforms:
|
||||
- DecodeImage:
|
||||
img_mode: BGR
|
||||
channel_first: false
|
||||
- CTCLabelEncode:
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 320]
|
||||
- KeepKeys:
|
||||
keep_keys:
|
||||
- image
|
||||
- label
|
||||
- length
|
||||
loader:
|
||||
shuffle: false
|
||||
drop_last: false
|
||||
batch_size_per_card: 128
|
||||
num_workers: 8
|
|
@ -0,0 +1,116 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 50
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 5
|
||||
save_model_dir: ./output/table_mv3/
|
||||
save_epoch_step: 5
|
||||
# evaluation is run every 400 iterations after the 0th iteration
|
||||
eval_batch_step: [0, 400]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
# for data or label process
|
||||
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
|
||||
character_type: en
|
||||
max_text_length: 100
|
||||
max_elem_length: 500
|
||||
max_cell_num: 500
|
||||
infer_mode: False
|
||||
process_total_num: 0
|
||||
process_cut_num: 0
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
clip_norm: 5.0
|
||||
lr:
|
||||
learning_rate: 0.001
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0.00000
|
||||
|
||||
Architecture:
|
||||
model_type: table
|
||||
algorithm: TableAttn
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 1.0
|
||||
model_name: small
|
||||
disable_se: True
|
||||
Head:
|
||||
name: TableAttentionHead
|
||||
hidden_size: 256
|
||||
l2_decay: 0.00001
|
||||
loc_type: 2
|
||||
|
||||
Loss:
|
||||
name: TableAttentionLoss
|
||||
structure_weight: 100.0
|
||||
loc_weight: 10000.0
|
||||
|
||||
PostProcess:
|
||||
name: TableLabelDecode
|
||||
|
||||
Metric:
|
||||
name: TableMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: PubTabDataSet
|
||||
data_dir: train_data/table/pubtabnet/train/
|
||||
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- ResizeTableImage:
|
||||
max_len: 488
|
||||
- TableLabelEncode:
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- PaddingTableImage:
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 32
|
||||
drop_last: True
|
||||
num_workers: 1
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: PubTabDataSet
|
||||
data_dir: train_data/table/pubtabnet/val/
|
||||
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- ResizeTableImage:
|
||||
max_len: 488
|
||||
- TableLabelEncode:
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- PaddingTableImage:
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 16
|
||||
num_workers: 1
|
|
@ -465,8 +465,12 @@ public class MainActivity extends AppCompatActivity {
|
|||
}
|
||||
|
||||
public void btn_load_model_click(View view) {
|
||||
tvStatus.setText("STATUS: load model ......");
|
||||
loadModel();
|
||||
if (predictor.isLoaded()){
|
||||
tvStatus.setText("STATUS: model has been loaded");
|
||||
}else{
|
||||
tvStatus.setText("STATUS: load model ......");
|
||||
loadModel();
|
||||
}
|
||||
}
|
||||
|
||||
public void btn_run_model_click(View view) {
|
||||
|
|
|
@ -194,26 +194,25 @@ public class Predictor {
|
|||
"supported!");
|
||||
return false;
|
||||
}
|
||||
int[] channelStride = new int[]{width * height, width * height * 2};
|
||||
int p = scaleImage.getPixel(scaleImage.getWidth() - 1, scaleImage.getHeight() - 1);
|
||||
for (int y = 0; y < height; y++) {
|
||||
for (int x = 0; x < width; x++) {
|
||||
int color = scaleImage.getPixel(x, y);
|
||||
float[] rgb = new float[]{(float) red(color) / 255.0f, (float) green(color) / 255.0f,
|
||||
(float) blue(color) / 255.0f};
|
||||
inputData[y * width + x] = (rgb[channelIdx[0]] - inputMean[0]) / inputStd[0];
|
||||
inputData[y * width + x + channelStride[0]] = (rgb[channelIdx[1]] - inputMean[1]) / inputStd[1];
|
||||
inputData[y * width + x + channelStride[1]] = (rgb[channelIdx[2]] - inputMean[2]) / inputStd[2];
|
||||
|
||||
}
|
||||
int[] channelStride = new int[]{width * height, width * height * 2};
|
||||
int[] pixels=new int[width*height];
|
||||
scaleImage.getPixels(pixels,0,scaleImage.getWidth(),0,0,scaleImage.getWidth(),scaleImage.getHeight());
|
||||
for (int i = 0; i < pixels.length; i++) {
|
||||
int color = pixels[i];
|
||||
float[] rgb = new float[]{(float) red(color) / 255.0f, (float) green(color) / 255.0f,
|
||||
(float) blue(color) / 255.0f};
|
||||
inputData[i] = (rgb[channelIdx[0]] - inputMean[0]) / inputStd[0];
|
||||
inputData[i + channelStride[0]] = (rgb[channelIdx[1]] - inputMean[1]) / inputStd[1];
|
||||
inputData[i+ channelStride[1]] = (rgb[channelIdx[2]] - inputMean[2]) / inputStd[2];
|
||||
}
|
||||
} else if (channels == 1) {
|
||||
for (int y = 0; y < height; y++) {
|
||||
for (int x = 0; x < width; x++) {
|
||||
int color = inputImage.getPixel(x, y);
|
||||
float gray = (float) (red(color) + green(color) + blue(color)) / 3.0f / 255.0f;
|
||||
inputData[y * width + x] = (gray - inputMean[0]) / inputStd[0];
|
||||
}
|
||||
int[] pixels=new int[width*height];
|
||||
scaleImage.getPixels(pixels,0,scaleImage.getWidth(),0,0,scaleImage.getWidth(),scaleImage.getHeight());
|
||||
for (int i = 0; i < pixels.length; i++) {
|
||||
int color = pixels[i];
|
||||
float gray = (float) (red(color) + green(color) + blue(color)) / 3.0f / 255.0f;
|
||||
inputData[i] = (gray - inputMean[0]) / inputStd[0];
|
||||
}
|
||||
} else {
|
||||
Log.i(TAG, "Unsupported channel size " + Integer.toString(channels) + ", only channel 1 and 3 is " +
|
||||
|
|
|
@ -13,7 +13,6 @@ SET(TENSORRT_DIR "" CACHE PATH "Compile demo with TensorRT")
|
|||
|
||||
set(DEMO_NAME "ocr_system")
|
||||
|
||||
|
||||
macro(safe_set_static_flag)
|
||||
foreach(flag_var
|
||||
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
|
||||
|
|
|
@ -14,7 +14,7 @@ PaddleOCR在Windows 平台下基于`Visual Studio 2019 Community` 进行了测
|
|||
|
||||
### Step1: 下载PaddlePaddle C++ 预测库 fluid_inference
|
||||
|
||||
PaddlePaddle C++ 预测库针对不同的`CPU`和`CUDA`版本提供了不同的预编译版本,请根据实际情况下载: [C++预测库下载列表](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/05_inference_deployment/inference/windows_cpp_inference.html)
|
||||
PaddlePaddle C++ 预测库针对不同的`CPU`和`CUDA`版本提供了不同的预编译版本,请根据实际情况下载: [C++预测库下载列表](https://paddleinference.paddlepaddle.org.cn/user_guides/download_lib.html#windows)
|
||||
|
||||
解压后`D:\projects\fluid_inference`目录包含内容为:
|
||||
```
|
||||
|
@ -93,3 +93,5 @@ cd D:\projects\PaddleOCR\deploy\cpp_infer\out\build\x64-Release
|
|||
|
||||
### 注意
|
||||
* 在Windows下的终端中执行文件exe时,可能会发生乱码的现象,此时需要在终端中输入`CHCP 65001`,将终端的编码方式由GBK编码(默认)改为UTF-8编码,更加具体的解释可以参考这篇博客:[https://blog.csdn.net/qq_35038153/article/details/78430359](https://blog.csdn.net/qq_35038153/article/details/78430359)。
|
||||
|
||||
* 编译时,如果报错`错误:C1083 无法打开包括文件:"dirent.h":No such file or directory`,可以参考该[文档](https://blog.csdn.net/Dora_blank/article/details/117740837#41_C1083_direnthNo_such_file_or_directory_54),新建`dirent.h`文件,并添加到`VC++`的包含目录中。
|
||||
|
|
|
@ -18,6 +18,7 @@ PaddleOCR模型部署。
|
|||
* 首先需要从opencv官网上下载在Linux环境下源码编译的包,以opencv3.4.7为例,下载命令如下。
|
||||
|
||||
```
|
||||
cd deploy/cpp_infer
|
||||
wget https://github.com/opencv/opencv/archive/3.4.7.tar.gz
|
||||
tar -xf 3.4.7.tar.gz
|
||||
```
|
||||
|
@ -141,11 +142,11 @@ build/paddle_inference_install_dir/
|
|||
```
|
||||
inference/
|
||||
|-- det_db
|
||||
| |--inference.pdparams
|
||||
| |--inference.pdimodel
|
||||
| |--inference.pdiparams
|
||||
| |--inference.pdmodel
|
||||
|-- rec_rcnn
|
||||
| |--inference.pdparams
|
||||
| |--inference.pdparams
|
||||
| |--inference.pdiparams
|
||||
| |--inference.pdmodel
|
||||
```
|
||||
|
||||
|
||||
|
@ -184,7 +185,7 @@ cmake .. \
|
|||
make -j
|
||||
```
|
||||
|
||||
`OPENCV_DIR`为opencv编译安装的地址;`LIB_DIR`为下载(`paddle_inference`文件夹)或者编译生成的Paddle预测库地址(`build/paddle_inference_install_dir`文件夹);`CUDA_LIB_DIR`为cuda库文件地址,在docker中为`/usr/local/cuda/lib64`;`CUDNN_LIB_DIR`为cudnn库文件地址,在docker中为`/usr/lib/x86_64-linux-gnu/`。
|
||||
`OPENCV_DIR`为opencv编译安装的地址;`LIB_DIR`为下载(`paddle_inference`文件夹)或者编译生成的Paddle预测库地址(`build/paddle_inference_install_dir`文件夹);`CUDA_LIB_DIR`为cuda库文件地址,在docker中为`/usr/local/cuda/lib64`;`CUDNN_LIB_DIR`为cudnn库文件地址,在docker中为`/usr/lib/x86_64-linux-gnu/`。**注意**:以上路径都写绝对路径,不要写相对路径。
|
||||
|
||||
|
||||
* 编译完成之后,会在`build`文件夹下生成一个名为`ocr_system`的可执行文件。
|
||||
|
|
|
@ -18,6 +18,7 @@ PaddleOCR model deployment.
|
|||
* First of all, you need to download the source code compiled package in the Linux environment from the opencv official website. Taking opencv3.4.7 as an example, the download command is as follows.
|
||||
|
||||
```
|
||||
cd deploy/cpp_infer
|
||||
wget https://github.com/opencv/opencv/archive/3.4.7.tar.gz
|
||||
tar -xf 3.4.7.tar.gz
|
||||
```
|
||||
|
@ -144,11 +145,11 @@ Among them, `paddle` is the Paddle library required for C++ prediction later, an
|
|||
```
|
||||
inference/
|
||||
|-- det_db
|
||||
| |--inference.pdparams
|
||||
| |--inference.pdimodel
|
||||
| |--inference.pdiparams
|
||||
| |--inference.pdmodel
|
||||
|-- rec_rcnn
|
||||
| |--inference.pdparams
|
||||
| |--inference.pdparams
|
||||
| |--inference.pdiparams
|
||||
| |--inference.pdmodel
|
||||
```
|
||||
|
||||
|
||||
|
|
|
@ -668,7 +668,7 @@ void DisposeOutPts(OutPt *&pp) {
|
|||
//------------------------------------------------------------------------------
|
||||
|
||||
inline void InitEdge(TEdge *e, TEdge *eNext, TEdge *ePrev, const IntPoint &Pt) {
|
||||
std::memset(e, 0, sizeof(TEdge));
|
||||
std::memset(e, int(0), sizeof(TEdge));
|
||||
e->Next = eNext;
|
||||
e->Prev = ePrev;
|
||||
e->Curr = Pt;
|
||||
|
@ -1895,17 +1895,17 @@ void Clipper::InsertLocalMinimaIntoAEL(const cInt botY) {
|
|||
TEdge *rb = lm->RightBound;
|
||||
|
||||
OutPt *Op1 = 0;
|
||||
if (!lb) {
|
||||
if (!lb || !rb) {
|
||||
// nb: don't insert LB into either AEL or SEL
|
||||
InsertEdgeIntoAEL(rb, 0);
|
||||
SetWindingCount(*rb);
|
||||
if (IsContributing(*rb))
|
||||
Op1 = AddOutPt(rb, rb->Bot);
|
||||
} else if (!rb) {
|
||||
InsertEdgeIntoAEL(lb, 0);
|
||||
SetWindingCount(*lb);
|
||||
if (IsContributing(*lb))
|
||||
Op1 = AddOutPt(lb, lb->Bot);
|
||||
//} else if (!rb) {
|
||||
// InsertEdgeIntoAEL(lb, 0);
|
||||
// SetWindingCount(*lb);
|
||||
// if (IsContributing(*lb))
|
||||
// Op1 = AddOutPt(lb, lb->Bot);
|
||||
InsertScanbeam(lb->Top.Y);
|
||||
} else {
|
||||
InsertEdgeIntoAEL(lb, 0);
|
||||
|
@ -2547,13 +2547,13 @@ void Clipper::ProcessHorizontal(TEdge *horzEdge) {
|
|||
if (dir == dLeftToRight) {
|
||||
maxIt = m_Maxima.begin();
|
||||
while (maxIt != m_Maxima.end() && *maxIt <= horzEdge->Bot.X)
|
||||
maxIt++;
|
||||
++maxIt;
|
||||
if (maxIt != m_Maxima.end() && *maxIt >= eLastHorz->Top.X)
|
||||
maxIt = m_Maxima.end();
|
||||
} else {
|
||||
maxRit = m_Maxima.rbegin();
|
||||
while (maxRit != m_Maxima.rend() && *maxRit > horzEdge->Bot.X)
|
||||
maxRit++;
|
||||
++maxRit;
|
||||
if (maxRit != m_Maxima.rend() && *maxRit <= eLastHorz->Top.X)
|
||||
maxRit = m_Maxima.rend();
|
||||
}
|
||||
|
@ -2576,13 +2576,13 @@ void Clipper::ProcessHorizontal(TEdge *horzEdge) {
|
|||
while (maxIt != m_Maxima.end() && *maxIt < e->Curr.X) {
|
||||
if (horzEdge->OutIdx >= 0 && !IsOpen)
|
||||
AddOutPt(horzEdge, IntPoint(*maxIt, horzEdge->Bot.Y));
|
||||
maxIt++;
|
||||
++maxIt;
|
||||
}
|
||||
} else {
|
||||
while (maxRit != m_Maxima.rend() && *maxRit > e->Curr.X) {
|
||||
if (horzEdge->OutIdx >= 0 && !IsOpen)
|
||||
AddOutPt(horzEdge, IntPoint(*maxRit, horzEdge->Bot.Y));
|
||||
maxRit++;
|
||||
++maxRit;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -47,16 +47,13 @@ void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
|
|||
e /= 255.0;
|
||||
}
|
||||
(*im).convertTo(*im, CV_32FC3, e);
|
||||
for (int h = 0; h < im->rows; h++) {
|
||||
for (int w = 0; w < im->cols; w++) {
|
||||
im->at<cv::Vec3f>(h, w)[0] =
|
||||
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) * scale[0];
|
||||
im->at<cv::Vec3f>(h, w)[1] =
|
||||
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) * scale[1];
|
||||
im->at<cv::Vec3f>(h, w)[2] =
|
||||
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) * scale[2];
|
||||
}
|
||||
std::vector<cv::Mat> bgr_channels(3);
|
||||
cv::split(*im, bgr_channels);
|
||||
for (auto i = 0; i < bgr_channels.size(); i++) {
|
||||
bgr_channels[i].convertTo(bgr_channels[i], CV_32FC1, 1.0 * scale[i],
|
||||
(0.0 - mean[i]) * scale[i]);
|
||||
}
|
||||
cv::merge(bgr_channels, *im);
|
||||
}
|
||||
|
||||
void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||
|
|
|
@ -29,7 +29,8 @@ deploy/hubserving/ocr_system/
|
|||
### 1. 准备环境
|
||||
```shell
|
||||
# 安装paddlehub
|
||||
pip3 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
# paddlehub 需要 python>3.6.2
|
||||
pip3 install paddlehub==2.1.0 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
```
|
||||
|
||||
### 2. 下载推理模型
|
||||
|
|
|
@ -30,7 +30,8 @@ The following steps take the 2-stage series service as an example. If only the d
|
|||
### 1. Prepare the environment
|
||||
```shell
|
||||
# Install paddlehub
|
||||
pip3 install paddlehub --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
# python>3.6.2 is required bt paddlehub
|
||||
pip3 install paddlehub==2.1.0 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
```
|
||||
|
||||
### 2. Download inference model
|
||||
|
|
|
@ -37,6 +37,17 @@ from paddleslim.dygraph.quant import QAT
|
|||
from ppocr.data import build_dataloader
|
||||
|
||||
|
||||
def export_single_model(quanter, model, infer_shape, save_path, logger):
|
||||
quanter.save_quantized_model(
|
||||
model,
|
||||
save_path,
|
||||
input_spec=[
|
||||
paddle.static.InputSpec(
|
||||
shape=[None] + infer_shape, dtype='float32')
|
||||
])
|
||||
logger.info('inference QAT model is saved to {}'.format(save_path))
|
||||
|
||||
|
||||
def main():
|
||||
############################################################################################################
|
||||
# 1. quantization configs
|
||||
|
@ -76,14 +87,21 @@ def main():
|
|||
# for rec algorithm
|
||||
if hasattr(post_process_class, 'character'):
|
||||
char_num = len(getattr(post_process_class, 'character'))
|
||||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
if config['Architecture']["algorithm"] in ["Distillation",
|
||||
]: # distillation model
|
||||
for key in config['Architecture']["Models"]:
|
||||
config['Architecture']["Models"][key]["Head"][
|
||||
'out_channels'] = char_num
|
||||
else: # base rec model
|
||||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
|
||||
model = build_model(config['Architecture'])
|
||||
|
||||
# get QAT model
|
||||
quanter = QAT(config=quant_config)
|
||||
quanter.quantize(model)
|
||||
|
||||
init_model(config, model, logger)
|
||||
init_model(config, model)
|
||||
model.eval()
|
||||
|
||||
# build metric
|
||||
|
@ -92,25 +110,30 @@ def main():
|
|||
# build dataloader
|
||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
model_type = config['Architecture']['model_type']
|
||||
# start eval
|
||||
metirc = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class)
|
||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class, model_type, use_srn)
|
||||
|
||||
logger.info('metric eval ***************')
|
||||
for k, v in metirc.items():
|
||||
for k, v in metric.items():
|
||||
logger.info('{}:{}'.format(k, v))
|
||||
|
||||
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
|
||||
infer_shape = [3, 32, 100] if config['Architecture'][
|
||||
'model_type'] != "det" else [3, 640, 640]
|
||||
|
||||
quanter.save_quantized_model(
|
||||
model,
|
||||
save_path,
|
||||
input_spec=[
|
||||
paddle.static.InputSpec(
|
||||
shape=[None] + infer_shape, dtype='float32')
|
||||
])
|
||||
logger.info('inference QAT model is saved to {}'.format(save_path))
|
||||
save_path = config["Global"]["save_inference_dir"]
|
||||
|
||||
arch_config = config["Architecture"]
|
||||
if arch_config["algorithm"] in ["Distillation", ]: # distillation model
|
||||
for idx, name in enumerate(model.model_name_list):
|
||||
sub_model_save_path = os.path.join(save_path, name, "inference")
|
||||
export_single_model(quanter, model.model_list[idx], infer_shape,
|
||||
sub_model_save_path, logger)
|
||||
else:
|
||||
save_path = os.path.join(save_path, "inference")
|
||||
export_single_model(quanter, model, infer_shape, save_path, logger)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -109,9 +109,18 @@ def main(config, device, logger, vdl_writer):
|
|||
# for rec algorithm
|
||||
if hasattr(post_process_class, 'character'):
|
||||
char_num = len(getattr(post_process_class, 'character'))
|
||||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
if config['Architecture']["algorithm"] in ["Distillation",
|
||||
]: # distillation model
|
||||
for key in config['Architecture']["Models"]:
|
||||
config['Architecture']["Models"][key]["Head"][
|
||||
'out_channels'] = char_num
|
||||
else: # base rec model
|
||||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
model = build_model(config['Architecture'])
|
||||
|
||||
quanter = QAT(config=quant_config, act_preprocess=PACT)
|
||||
quanter.quantize(model)
|
||||
|
||||
if config['Global']['distributed']:
|
||||
model = paddle.DataParallel(model)
|
||||
|
||||
|
@ -132,8 +141,6 @@ def main(config, device, logger, vdl_writer):
|
|||
|
||||
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
|
||||
format(len(train_dataloader), len(valid_dataloader)))
|
||||
quanter = QAT(config=quant_config, act_preprocess=PACT)
|
||||
quanter.quantize(model)
|
||||
|
||||
# start train
|
||||
program.train(config, train_dataloader, valid_dataloader, device, model,
|
||||
|
|
|
@ -111,9 +111,9 @@
|
|||
| 字段 | 用途 | 默认值 | 备注 |
|
||||
| :---------------------: | :---------------------: | :--------------: | :--------------------: |
|
||||
| **dataset** | 每次迭代返回一个样本 | - | - |
|
||||
| name | dataset类名 | SimpleDataSet | 目前支持`SimpleDataSet`和`LMDBDateSet` |
|
||||
| name | dataset类名 | SimpleDataSet | 目前支持`SimpleDataSet`和`LMDBDataSet` |
|
||||
| data_dir | 数据集图片存放路径 | ./train_data | \ |
|
||||
| label_file_list | 数据标签路径 | ["./train_data/train_list.txt"] | dataset为LMDBDateSet时不需要此参数 |
|
||||
| label_file_list | 数据标签路径 | ["./train_data/train_list.txt"] | dataset为LMDBDataSet时不需要此参数 |
|
||||
| ratio_list | 数据集的比例 | [1.0] | 若label_file_list中有两个train_list,且ratio_list为[0.4,0.6],则从train_list1中采样40%,从train_list2中采样60%组合整个dataset |
|
||||
| transforms | 对图片和标签进行变换的方法列表 | [DecodeImage,CTCLabelEncode,RecResizeImg,KeepKeys] | 见[ppocr/data/imaug](../../ppocr/data/imaug) |
|
||||
| **loader** | dataloader相关 | - | |
|
||||
|
|
|
@ -18,9 +18,9 @@ PaddleOCR 也提供了数据格式转换脚本,可以将官网 label 转换支
|
|||
|
||||
```
|
||||
# 将官网下载的标签文件转换为 train_icdar2015_label.txt
|
||||
python gen_label.py --mode="det" --root_path="icdar_c4_train_imgs/" \
|
||||
--input_path="ch4_training_localization_transcription_gt" \
|
||||
--output_label="train_icdar2015_label.txt"
|
||||
python gen_label.py --mode="det" --root_path="/path/to/icdar_c4_train_imgs/" \
|
||||
--input_path="/path/to/ch4_training_localization_transcription_gt" \
|
||||
--output_label="/path/to/train_icdar2015_label.txt"
|
||||
```
|
||||
|
||||
解压数据集和下载标注文件后,PaddleOCR/train_data/ 有两个文件夹和两个文件,分别是:
|
||||
|
|
|
@ -147,12 +147,12 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_m
|
|||
|
||||
如果输入图片的分辨率比较大,而且想使用更大的分辨率预测,可以设置det_limit_side_len 为想要的值,比如1216:
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||
```
|
||||
|
||||
如果想使用CPU进行预测,执行命令如下
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
```
|
||||
|
||||
<a name="DB文本检测模型推理"></a>
|
||||
|
@ -221,7 +221,7 @@ python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Gl
|
|||
|
||||
```
|
||||
|
||||
**SAST文本检测模型推理,需要设置参数`--det_algorithm="SAST"`,同时,还需要增加参数`--det_sast_polygon=True`,**可以执行如下命令:
|
||||
SAST文本检测模型推理,需要设置参数`--det_algorithm="SAST"`,同时,还需要增加参数`--det_sast_polygon=True`,可以执行如下命令:
|
||||
```
|
||||
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_sast_tt/" --det_sast_polygon=True
|
||||
```
|
||||
|
|
|
@ -0,0 +1,251 @@
|
|||
# 知识蒸馏
|
||||
|
||||
|
||||
## 1. 简介
|
||||
|
||||
### 1.1 知识蒸馏介绍
|
||||
|
||||
近年来,深度神经网络在计算机视觉、自然语言处理等领域被验证是一种极其有效的解决问题的方法。通过构建合适的神经网络,加以训练,最终网络模型的性能指标基本上都会超过传统算法。
|
||||
|
||||
在数据量足够大的情况下,通过合理构建网络模型的方式增加其参数量,可以显著改善模型性能,但是这又带来了模型复杂度急剧提升的问题。大模型在实际场景中使用的成本较高。
|
||||
|
||||
深度神经网络一般有较多的参数冗余,目前有几种主要的方法对模型进行压缩,减小其参数量。如裁剪、量化、知识蒸馏等,其中知识蒸馏是指使用教师模型(teacher model)去指导学生模型(student model)学习特定任务,保证小模型在参数量不变的情况下,得到比较大的性能提升。
|
||||
|
||||
此外,在知识蒸馏任务中,也衍生出了互学习的模型训练方法,论文[Deep Mutual Learning](https://arxiv.org/abs/1706.00384)中指出,使用两个完全相同的模型在训练的过程中互相监督,可以达到比单个模型训练更好的效果。
|
||||
|
||||
### 1.2 PaddleOCR知识蒸馏简介
|
||||
|
||||
无论是大模型蒸馏小模型,还是小模型之间互相学习,更新参数,他们本质上是都是不同模型之间输出或者特征图(feature map)之间的相互监督,区别仅在于 (1) 模型是否需要固定参数。(2) 模型是否需要加载预训练模型。
|
||||
|
||||
对于大模型蒸馏小模型的情况,大模型一般需要加载预训练模型并固定参数;对于小模型之间互相蒸馏的情况,小模型一般都不加载预训练模型,参数也都是可学习的状态。
|
||||
|
||||
在知识蒸馏任务中,不只有2个模型之间进行蒸馏的情况,多个模型之间互相学习的情况也非常普遍。因此在知识蒸馏代码框架中,也有必要支持该种类别的蒸馏方法。
|
||||
|
||||
PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要的特点:
|
||||
- 支持任意网络的互相学习,不要求子网络结构完全一致或者具有预训练模型;同时子网络数量也没有任何限制,只需要在配置文件中添加即可。
|
||||
- 支持loss函数通过配置文件任意配置,不仅可以使用某种loss,也可以使用多种loss的组合
|
||||
- 支持知识蒸馏训练、预测、评估与导出等所有模型相关的环境,方便使用与部署。
|
||||
|
||||
|
||||
通过知识蒸馏,在中英文通用文字识别任务中,不增加任何预测耗时的情况下,可以给模型带来3%以上的精度提升,结合学习率调整策略以及模型结构微调策略,最终提升提升超过5%。
|
||||
|
||||
|
||||
|
||||
## 2. 配置文件解析
|
||||
|
||||
在知识蒸馏训练的过程中,数据预处理、优化器、学习率、全局的一些属性没有任何变化。模型结构、损失函数、后处理、指标计算等模块的配置文件需要进行微调。
|
||||
|
||||
下面以识别与检测的知识蒸馏配置文件为例,对知识蒸馏的训练与配置进行解析。
|
||||
|
||||
### 2.1 识别配置文件解析
|
||||
|
||||
配置文件在[rec_chinese_lite_train_distillation_v2.1.yml](../../configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml)。
|
||||
|
||||
#### 2.1.1 模型结构
|
||||
|
||||
知识蒸馏任务中,模型结构配置如下所示。
|
||||
|
||||
```yaml
|
||||
Architecture:
|
||||
model_type: &model_type "rec" # 模型类别,rec、det等,每个子网络的的模型类别都与
|
||||
name: DistillationModel # 结构名称,蒸馏任务中,为DistillationModel,用于构建对应的结构
|
||||
algorithm: Distillation # 算法名称
|
||||
Models: # 模型,包含子网络的配置信息
|
||||
Teacher: # 子网络名称,至少需要包含`pretrained`与`freeze_params`信息,其他的参数为子网络的构造参数
|
||||
pretrained: # 该子网络是否需要加载预训练模型
|
||||
freeze_params: false # 是否需要固定参数
|
||||
return_all_feats: true # 子网络的参数,表示是否需要返回所有的features,如果为False,则只返回最后的输出
|
||||
model_type: *model_type # 模型类别
|
||||
algorithm: CRNN # 子网络的算法名称,该子网络剩余参与均为构造参数,与普通的模型训练配置一致
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
Student: # 另外一个子网络,这里给的是DML的蒸馏示例,两个子网络结构相同,均需要学习参数
|
||||
pretrained: # 下面的组网参数同上
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
```
|
||||
|
||||
当然,这里如果希望添加更多的子网络进行训练,也可以按照`Student`与`Teacher`的添加方式,在配置文件中添加相应的字段。比如说如果希望有3个模型互相监督,共同训练,那么`Architecture`可以写为如下格式。
|
||||
|
||||
```yaml
|
||||
Architecture:
|
||||
model_type: &model_type "rec"
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Teacher:
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
Student:
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
Student2: # 知识蒸馏任务中引入的新的子网络,其他部分与上述配置相同
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
```
|
||||
|
||||
最终该模型训练时,包含3个子网络:`Teacher`, `Student`, `Student2`。
|
||||
|
||||
蒸馏模型`DistillationModel`类的具体实现代码可以参考[distillation_model.py](../../ppocr/modeling/architectures/distillation_model.py)。
|
||||
|
||||
最终模型`forward`输出为一个字典,key为所有的子网络名称,例如这里为`Student`与`Teacher`,value为对应子网络的输出,可以为`Tensor`(只返回该网络的最后一层)和`dict`(也返回了中间的特征信息)。
|
||||
|
||||
在识别任务中,为了添加更多损失函数,保证蒸馏方法的可扩展性,将每个子网络的输出保存为`dict`,其中包含子模块输出。以该识别模型为例,每个子网络的输出结果均为`dict`,key包含`backbone_out`,`neck_out`, `head_out`,`value`为对应模块的tensor,最终对于上述配置文件,`DistillationModel`的输出格式如下。
|
||||
|
||||
```json
|
||||
{
|
||||
"Teacher": {
|
||||
"backbone_out": tensor,
|
||||
"neck_out": tensor,
|
||||
"head_out": tensor,
|
||||
},
|
||||
"Student": {
|
||||
"backbone_out": tensor,
|
||||
"neck_out": tensor,
|
||||
"head_out": tensor,
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### 2.1.2 损失函数
|
||||
|
||||
知识蒸馏任务中,损失函数配置如下所示。
|
||||
|
||||
```yaml
|
||||
Loss:
|
||||
name: CombinedLoss # 损失函数名称,基于改名称,构建用于损失函数的类
|
||||
loss_config_list: # 损失函数配置文件列表,为CombinedLoss的必备函数
|
||||
- DistillationCTCLoss: # 基于蒸馏的CTC损失函数,继承自标准的CTC loss
|
||||
weight: 1.0 # 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
|
||||
model_name_list: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,与gt计算CTC loss
|
||||
key: head_out # 取子网络输出dict中,该key对应的tensor
|
||||
- DistillationDMLLoss: # 蒸馏的DML损失函数,继承自标准的DMLLoss
|
||||
weight: 1.0 # 权重
|
||||
act: "softmax" # 激活函数,对输入使用激活函数处理,可以为softmax, sigmoid或者为None,默认为None
|
||||
model_name_pairs: # 用于计算DML loss的子网络名称对,如果希望计算其他子网络的DML loss,可以在列表下面继续填充
|
||||
- ["Student", "Teacher"]
|
||||
key: head_out # 取子网络输出dict中,该key对应的tensor
|
||||
- DistillationDistanceLoss: # 蒸馏的距离损失函数
|
||||
weight: 1.0 # 权重
|
||||
mode: "l2" # 距离计算方法,目前支持l1, l2, smooth_l1
|
||||
model_name_pairs: # 用于计算distance loss的子网络名称对
|
||||
- ["Student", "Teacher"]
|
||||
key: backbone_out # 取子网络输出dict中,该key对应的tensor
|
||||
```
|
||||
|
||||
上述损失函数中,所有的蒸馏损失函数均继承自标准的损失函数类,主要功能为: 对蒸馏模型的输出进行解析,找到用于计算损失的中间节点(tensor),再使用标准的损失函数类去计算。
|
||||
|
||||
以上述配置为例,最终蒸馏训练的损失函数包含下面3个部分。
|
||||
|
||||
- `Student`和`Teacher`的最终输出(`head_out`)与gt的CTC loss,权重为1。在这里因为2个子网络都需要更新参数,因此2者都需要计算与g的loss。
|
||||
- `Student`和`Teacher`的最终输出(`head_out`)之间的DML loss,权重为1。
|
||||
- `Student`和`Teacher`的骨干网络输出(`backbone_out`)之间的l2 loss,权重为1。
|
||||
|
||||
关于`CombinedLoss`更加具体的实现可以参考: [combined_loss.py](../../ppocr/losses/combined_loss.py#L23)。关于`DistillationCTCLoss`等蒸馏损失函数更加具体的实现可以参考[distillation_loss.py](../../ppocr/losses/distillation_loss.py)。
|
||||
|
||||
|
||||
#### 2.1.3 后处理
|
||||
|
||||
知识蒸馏任务中,后处理配置如下所示。
|
||||
|
||||
```yaml
|
||||
PostProcess:
|
||||
name: DistillationCTCLabelDecode # 蒸馏任务的CTC解码后处理,继承自标准的CTCLabelDecode类
|
||||
model_name: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,进行解码
|
||||
key: head_out # 取子网络输出dict中,该key对应的tensor
|
||||
```
|
||||
|
||||
以上述配置为例,最终会同时计算`Student`和`Teahcer` 2个子网络的CTC解码输出,返回一个`dict`,`key`为用于处理的子网络名称,`value`为用于处理的子网络列表。
|
||||
|
||||
关于`DistillationCTCLabelDecode`更加具体的实现可以参考: [rec_postprocess.py](../../ppocr/postprocess/rec_postprocess.py#L128)
|
||||
|
||||
|
||||
#### 2.1.4 指标计算
|
||||
|
||||
知识蒸馏任务中,指标计算配置如下所示。
|
||||
|
||||
```yaml
|
||||
Metric:
|
||||
name: DistillationMetric # 蒸馏任务的CTC解码后处理,继承自标准的CTCLabelDecode类
|
||||
base_metric_name: RecMetric # 指标计算的基类,对于模型的输出,会基于该类,计算指标
|
||||
main_indicator: acc # 指标的名称
|
||||
key: "Student" # 选取该子网络的 main_indicator 作为作为保存保存best model的判断标准
|
||||
```
|
||||
|
||||
以上述配置为例,最终会使用`Student`子网络的acc指标作为保存best model的判断指标,同时,日志中也会打印出所有子网络的acc指标。
|
||||
|
||||
关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24)。
|
||||
|
||||
|
||||
### 2.2 检测配置文件解析
|
||||
|
||||
* coming soon!
|
|
@ -243,7 +243,7 @@ Optimizer:
|
|||
|
||||
Train:
|
||||
dataset:
|
||||
# 数据集格式,支持LMDBDateSet以及SimpleDataSet
|
||||
# 数据集格式,支持LMDBDataSet以及SimpleDataSet
|
||||
name: SimpleDataSet
|
||||
# 数据集路径
|
||||
data_dir: ./train_data/
|
||||
|
@ -263,7 +263,7 @@ Train:
|
|||
|
||||
Eval:
|
||||
dataset:
|
||||
# 数据集格式,支持LMDBDateSet以及SimpleDataSet
|
||||
# 数据集格式,支持LMDBDataSet以及SimpleDataSet
|
||||
name: SimpleDataSet
|
||||
# 数据集路径
|
||||
data_dir: ./train_data
|
||||
|
@ -330,6 +330,8 @@ PaddleOCR目前已支持80种(除中文外)语种识别,`configs/rec/multi
|
|||
|
||||
```
|
||||
|
||||
意大利文由拉丁字母组成,因此执行完命令后会得到名为 rec_latin_lite_train.yml 的配置文件。
|
||||
|
||||
2. 手动修改配置文件
|
||||
|
||||
您也可以手动修改模版中的以下几个字段:
|
||||
|
@ -375,7 +377,9 @@ PaddleOCR目前已支持80种(除中文外)语种识别,`configs/rec/multi
|
|||
|
||||
更多支持语种请参考: [多语言模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_ch/multi_languages.md#%E8%AF%AD%E7%A7%8D%E7%BC%A9%E5%86%99)
|
||||
|
||||
多语言模型训练方式与中文模型一致,训练数据集均为100w的合成数据,少量的字体可以在 [百度网盘](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA) 上下载,提取码:frgi。
|
||||
多语言模型训练方式与中文模型一致,训练数据集均为100w的合成数据,少量的字体可以通过下面两种方式下载。
|
||||
* [百度网盘](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA)。提取码:frgi。
|
||||
* [google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
|
||||
|
||||
如您希望在现有模型效果的基础上调优,请参考下列说明修改配置文件:
|
||||
|
||||
|
@ -393,7 +397,7 @@ Global:
|
|||
|
||||
Train:
|
||||
dataset:
|
||||
# 数据集格式,支持LMDBDateSet以及SimpleDataSet
|
||||
# 数据集格式,支持LMDBDataSet以及SimpleDataSet
|
||||
name: SimpleDataSet
|
||||
# 数据集路径
|
||||
data_dir: ./train_data/
|
||||
|
@ -403,7 +407,7 @@ Train:
|
|||
|
||||
Eval:
|
||||
dataset:
|
||||
# 数据集格式,支持LMDBDateSet以及SimpleDataSet
|
||||
# 数据集格式,支持LMDBDataSet以及SimpleDataSet
|
||||
name: SimpleDataSet
|
||||
# 数据集路径
|
||||
data_dir: ./train_data
|
||||
|
|
|
@ -5,26 +5,32 @@
|
|||
### 1.1 安装whl包
|
||||
|
||||
pip安装
|
||||
|
||||
```bash
|
||||
pip install "paddleocr>=2.0.1" # 推荐使用2.0.1+版本
|
||||
```
|
||||
|
||||
本地构建并安装
|
||||
|
||||
```bash
|
||||
python3 setup.py bdist_wheel
|
||||
pip3 install dist/paddleocr-x.x.x-py3-none-any.whl # x.x.x是paddleocr的版本号
|
||||
```
|
||||
|
||||
## 2 使用
|
||||
|
||||
### 2.1 代码使用
|
||||
|
||||
paddleocr whl包会自动下载ppocr轻量级模型作为默认模型,可以根据第3节**自定义模型**进行自定义更换。
|
||||
|
||||
* 检测+方向分类器+识别全流程
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
|
||||
# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换
|
||||
# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||
img_path = 'PaddleOCR/doc/imgs/11.jpg'
|
||||
result = ocr.ocr(img_path, cls=True)
|
||||
for line in result:
|
||||
|
@ -32,6 +38,7 @@ for line in result:
|
|||
|
||||
# 显示结果
|
||||
from PIL import Image
|
||||
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
|
@ -40,31 +47,36 @@ im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc
|
|||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
结果是一个list,每个item包含了文本框,文字和识别置信度
|
||||
|
||||
```bash
|
||||
[[[24.0, 36.0], [304.0, 34.0], [304.0, 72.0], [24.0, 74.0]], ['纯臻营养护发素', 0.964739]]
|
||||
[[[24.0, 80.0], [172.0, 80.0], [172.0, 104.0], [24.0, 104.0]], ['产品信息/参数', 0.98069626]]
|
||||
[[[24.0, 109.0], [333.0, 109.0], [333.0, 136.0], [24.0, 136.0]], ['(45元/每公斤,100公斤起订)', 0.9676722]]
|
||||
......
|
||||
```
|
||||
|
||||
结果可视化
|
||||
|
||||
<div align="center">
|
||||
<img src="../imgs_results/whl/11_det_rec.jpg" width="800">
|
||||
</div>
|
||||
|
||||
|
||||
* 检测+识别
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
ocr = PaddleOCR() # need to run only once to download and load model into memory
|
||||
|
||||
ocr = PaddleOCR() # need to run only once to download and load model into memory
|
||||
img_path = 'PaddleOCR/doc/imgs/11.jpg'
|
||||
result = ocr.ocr(img_path)
|
||||
result = ocr.ocr(img_path, cls=False)
|
||||
for line in result:
|
||||
print(line)
|
||||
|
||||
# 显示结果
|
||||
from PIL import Image
|
||||
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
|
@ -73,38 +85,46 @@ im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc
|
|||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
结果是一个list,每个item包含了文本框,文字和识别置信度
|
||||
|
||||
```bash
|
||||
[[[24.0, 36.0], [304.0, 34.0], [304.0, 72.0], [24.0, 74.0]], ['纯臻营养护发素', 0.964739]]
|
||||
[[[24.0, 80.0], [172.0, 80.0], [172.0, 104.0], [24.0, 104.0]], ['产品信息/参数', 0.98069626]]
|
||||
[[[24.0, 109.0], [333.0, 109.0], [333.0, 136.0], [24.0, 136.0]], ['(45元/每公斤,100公斤起订)', 0.9676722]]
|
||||
......
|
||||
```
|
||||
|
||||
结果可视化
|
||||
|
||||
<div align="center">
|
||||
<img src="../imgs_results/whl/11_det_rec.jpg" width="800">
|
||||
</div>
|
||||
|
||||
|
||||
* 方向分类器+识别
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR
|
||||
ocr = PaddleOCR(use_angle_cls=True) # need to run only once to download and load model into memory
|
||||
|
||||
ocr = PaddleOCR(use_angle_cls=True) # need to run only once to download and load model into memory
|
||||
img_path = 'PaddleOCR/doc/imgs_words/ch/word_1.jpg'
|
||||
result = ocr.ocr(img_path, det=False, cls=True)
|
||||
for line in result:
|
||||
print(line)
|
||||
```
|
||||
|
||||
结果是一个list,每个item只包含识别结果和识别置信度
|
||||
|
||||
```bash
|
||||
['韩国小馆', 0.9907421]
|
||||
```
|
||||
|
||||
* 单独执行检测
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
ocr = PaddleOCR() # need to run only once to download and load model into memory
|
||||
|
||||
ocr = PaddleOCR() # need to run only once to download and load model into memory
|
||||
img_path = 'PaddleOCR/doc/imgs/11.jpg'
|
||||
result = ocr.ocr(img_path, rec=False)
|
||||
for line in result:
|
||||
|
@ -118,13 +138,16 @@ im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/Pa
|
|||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
结果是一个list,每个item只包含文本框
|
||||
|
||||
```bash
|
||||
[[26.0, 457.0], [137.0, 457.0], [137.0, 477.0], [26.0, 477.0]]
|
||||
[[25.0, 425.0], [372.0, 425.0], [372.0, 448.0], [25.0, 448.0]]
|
||||
[[128.0, 397.0], [273.0, 397.0], [273.0, 414.0], [128.0, 414.0]]
|
||||
......
|
||||
```
|
||||
|
||||
结果可视化
|
||||
|
||||
|
||||
|
@ -133,29 +156,37 @@ im_show.save('result.jpg')
|
|||
</div>
|
||||
|
||||
* 单独执行识别
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR
|
||||
ocr = PaddleOCR() # need to run only once to download and load model into memory
|
||||
|
||||
ocr = PaddleOCR() # need to run only once to download and load model into memory
|
||||
img_path = 'PaddleOCR/doc/imgs_words/ch/word_1.jpg'
|
||||
result = ocr.ocr(img_path, det=False)
|
||||
for line in result:
|
||||
print(line)
|
||||
```
|
||||
|
||||
结果是一个list,每个item只包含识别结果和识别置信度
|
||||
|
||||
```bash
|
||||
['韩国小馆', 0.9907421]
|
||||
```
|
||||
|
||||
* 单独执行方向分类器
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR
|
||||
ocr = PaddleOCR(use_angle_cls=True) # need to run only once to download and load model into memory
|
||||
|
||||
ocr = PaddleOCR(use_angle_cls=True) # need to run only once to download and load model into memory
|
||||
img_path = 'PaddleOCR/doc/imgs_words/ch/word_1.jpg'
|
||||
result = ocr.ocr(img_path, det=False, rec=False, cls=True)
|
||||
for line in result:
|
||||
print(line)
|
||||
```
|
||||
|
||||
结果是一个list,每个item只包含分类结果和分类置信度
|
||||
|
||||
```bash
|
||||
['0', 0.9999924]
|
||||
```
|
||||
|
@ -163,15 +194,19 @@ for line in result:
|
|||
### 2.2 通过命令行使用
|
||||
|
||||
查看帮助信息
|
||||
|
||||
```bash
|
||||
paddleocr -h
|
||||
```
|
||||
|
||||
* 检测+方向分类器+识别全流程
|
||||
|
||||
```bash
|
||||
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --use_angle_cls true
|
||||
```
|
||||
|
||||
结果是一个list,每个item包含了文本框,文字和识别置信度
|
||||
|
||||
```bash
|
||||
[[[24.0, 36.0], [304.0, 34.0], [304.0, 72.0], [24.0, 74.0]], ['纯臻营养护发素', 0.964739]]
|
||||
[[[24.0, 80.0], [172.0, 80.0], [172.0, 104.0], [24.0, 104.0]], ['产品信息/参数', 0.98069626]]
|
||||
|
@ -180,10 +215,13 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --use_angle_cls true
|
|||
```
|
||||
|
||||
* 检测+识别
|
||||
|
||||
```bash
|
||||
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg
|
||||
```
|
||||
|
||||
结果是一个list,每个item包含了文本框,文字和识别置信度
|
||||
|
||||
```bash
|
||||
[[[24.0, 36.0], [304.0, 34.0], [304.0, 72.0], [24.0, 74.0]], ['纯臻营养护发素', 0.964739]]
|
||||
[[[24.0, 80.0], [172.0, 80.0], [172.0, 104.0], [24.0, 104.0]], ['产品信息/参数', 0.98069626]]
|
||||
|
@ -192,20 +230,25 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg
|
|||
```
|
||||
|
||||
* 方向分类器+识别
|
||||
|
||||
```bash
|
||||
paddleocr --image_dir PaddleOCR/doc/imgs_words/ch/word_1.jpg --use_angle_cls true --det false
|
||||
```
|
||||
|
||||
结果是一个list,每个item只包含识别结果和识别置信度
|
||||
|
||||
```bash
|
||||
['韩国小馆', 0.9907421]
|
||||
```
|
||||
|
||||
* 单独执行检测
|
||||
|
||||
```bash
|
||||
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --rec false
|
||||
```
|
||||
|
||||
结果是一个list,每个item只包含文本框
|
||||
|
||||
```bash
|
||||
[[26.0, 457.0], [137.0, 457.0], [137.0, 477.0], [26.0, 477.0]]
|
||||
[[25.0, 425.0], [372.0, 425.0], [372.0, 448.0], [25.0, 448.0]]
|
||||
|
@ -214,34 +257,42 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --rec false
|
|||
```
|
||||
|
||||
* 单独执行识别
|
||||
|
||||
```bash
|
||||
paddleocr --image_dir PaddleOCR/doc/imgs_words/ch/word_1.jpg --det false
|
||||
```
|
||||
|
||||
结果是一个list,每个item只包含识别结果和识别置信度
|
||||
|
||||
```bash
|
||||
['韩国小馆', 0.9907421]
|
||||
```
|
||||
|
||||
* 单独执行方向分类器
|
||||
|
||||
```bash
|
||||
paddleocr --image_dir PaddleOCR/doc/imgs_words/ch/word_1.jpg --use_angle_cls true --det false --rec false
|
||||
```
|
||||
|
||||
结果是一个list,每个item只包含分类结果和分类置信度
|
||||
|
||||
```bash
|
||||
['0', 0.9999924]
|
||||
```
|
||||
|
||||
## 3 自定义模型
|
||||
当内置模型无法满足需求时,需要使用到自己训练的模型。
|
||||
首先,参照[inference.md](./inference.md) 第一节转换将检测、分类和识别模型转换为inference模型,然后按照如下方式使用
|
||||
|
||||
当内置模型无法满足需求时,需要使用到自己训练的模型。 首先,参照[inference.md](./inference.md) 第一节转换将检测、分类和识别模型转换为inference模型,然后按照如下方式使用
|
||||
|
||||
### 3.1 代码使用
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
|
||||
# 模型路径下必须含有model和params文件
|
||||
ocr = PaddleOCR(det_model_dir='{your_det_model_dir}', rec_model_dir='{your_rec_model_dir}', rec_char_dict_path='{your_rec_char_dict_path}', cls_model_dir='{your_cls_model_dir}', use_angle_cls=True)
|
||||
ocr = PaddleOCR(det_model_dir='{your_det_model_dir}', rec_model_dir='{your_rec_model_dir}',
|
||||
rec_char_dict_path='{your_rec_char_dict_path}', cls_model_dir='{your_cls_model_dir}',
|
||||
use_angle_cls=True)
|
||||
img_path = 'PaddleOCR/doc/imgs/11.jpg'
|
||||
result = ocr.ocr(img_path, cls=True)
|
||||
for line in result:
|
||||
|
@ -249,6 +300,7 @@ for line in result:
|
|||
|
||||
# 显示结果
|
||||
from PIL import Image
|
||||
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
|
@ -269,11 +321,13 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
|
|||
### 4.1 网络图片
|
||||
|
||||
- 代码使用
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
from paddleocr import PaddleOCR, draw_ocr, download_with_progressbar
|
||||
|
||||
# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换
|
||||
# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||
img_path = 'http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg'
|
||||
result = ocr.ocr(img_path, cls=True)
|
||||
for line in result:
|
||||
|
@ -281,7 +335,9 @@ for line in result:
|
|||
|
||||
# 显示结果
|
||||
from PIL import Image
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
|
||||
download_with_progressbar(img_path, 'tmp.jpg')
|
||||
image = Image.open('tmp.jpg').convert('RGB')
|
||||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
|
@ -289,18 +345,24 @@ im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc
|
|||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
- 命令行模式
|
||||
|
||||
```bash
|
||||
paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg --use_angle_cls=true
|
||||
```
|
||||
|
||||
### 4.2 numpy数组
|
||||
|
||||
仅通过代码使用时支持numpy数组作为输入
|
||||
|
||||
```python
|
||||
import cv2
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
|
||||
# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换
|
||||
# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||
img_path = 'PaddleOCR/doc/imgs/11.jpg'
|
||||
img = cv2.imread(img_path)
|
||||
# img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY), 如果你自己训练的模型支持灰度图,可以将这句话的注释取消
|
||||
|
@ -310,6 +372,7 @@ for line in result:
|
|||
|
||||
# 显示结果
|
||||
from PIL import Image
|
||||
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
|
@ -355,3 +418,5 @@ im_show.save('result.jpg')
|
|||
| det | 前向时使用启动检测 | TRUE |
|
||||
| rec | 前向时是否启动识别 | TRUE |
|
||||
| cls | 前向时是否启动分类 (命令行模式下使用use_angle_cls控制前向是否启动分类) | FALSE |
|
||||
| show_log | 是否打印det和rec等信息 | FALSE |
|
||||
| type | 执行ocr或者表格结构化, 值可选['ocr','structure'] | ocr |
|
||||
|
|
|
@ -110,9 +110,9 @@ In ppocr, the network is divided into four stages: Transform, Backbone, Neck and
|
|||
| Parameter | Use | Defaults | Note |
|
||||
| :---------------------: | :---------------------: | :--------------: | :--------------------: |
|
||||
| **dataset** | Return one sample per iteration | - | - |
|
||||
| name | dataset class name | SimpleDataSet | Currently support`SimpleDataSet`,`LMDBDateSet` |
|
||||
| name | dataset class name | SimpleDataSet | Currently support`SimpleDataSet`,`LMDBDataSet` |
|
||||
| data_dir | Image folder path | ./train_data | \ |
|
||||
| label_file_list | Groundtruth file path | ["./train_data/train_list.txt"] | This parameter is not required when dataset is LMDBDateSet |
|
||||
| label_file_list | Groundtruth file path | ["./train_data/train_list.txt"] | This parameter is not required when dataset is LMDBDataSet |
|
||||
| ratio_list | Ratio of data set | [1.0] | If there are two train_lists in label_file_list and ratio_list is [0.4,0.6], 40% will be sampled from train_list1, and 60% will be sampled from train_list2 to combine the entire dataset |
|
||||
| transforms | List of methods to transform images and labels | [DecodeImage,CTCLabelEncode,RecResizeImg,KeepKeys] | see[ppocr/data/imaug](../../ppocr/data/imaug) |
|
||||
| **loader** | dataloader related | - | |
|
||||
|
|
|
@ -154,12 +154,12 @@ Set as `limit_type='min', det_limit_side_len=960`, it means that the shortest si
|
|||
|
||||
If the resolution of the input picture is relatively large and you want to use a larger resolution prediction, you can set det_limit_side_len to the desired value, such as 1216:
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||
```
|
||||
|
||||
If you want to use the CPU for prediction, execute the command as follows
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
```
|
||||
|
||||
<a name="DB_DETECTION"></a>
|
||||
|
@ -230,7 +230,7 @@ First, convert the model saved in the SAST text detection training process into
|
|||
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.save_inference_dir=./inference/det_sast_tt
|
||||
```
|
||||
|
||||
**For SAST curved text detection model inference, you need to set the parameter `--det_algorithm="SAST"` and `--det_sast_polygon=True`**, run the following command:
|
||||
For SAST curved text detection model inference, you need to set the parameter `--det_algorithm="SAST"` and `--det_sast_polygon=True`, run the following command:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_sast_tt/" --det_sast_polygon=True
|
||||
|
|
|
@ -237,7 +237,7 @@ Optimizer:
|
|||
|
||||
Train:
|
||||
dataset:
|
||||
# Type of dataset,we support LMDBDateSet and SimpleDataSet
|
||||
# Type of dataset,we support LMDBDataSet and SimpleDataSet
|
||||
name: SimpleDataSet
|
||||
# Path of dataset
|
||||
data_dir: ./train_data/
|
||||
|
@ -257,7 +257,7 @@ Train:
|
|||
|
||||
Eval:
|
||||
dataset:
|
||||
# Type of dataset,we support LMDBDateSet and SimpleDataSet
|
||||
# Type of dataset,we support LMDBDataSet and SimpleDataSet
|
||||
name: SimpleDataSet
|
||||
# Path of dataset
|
||||
data_dir: ./train_data
|
||||
|
@ -329,6 +329,7 @@ There are two ways to create the required configuration file::
|
|||
...
|
||||
|
||||
```
|
||||
Italian is made up of Latin letters, so after executing the command, you will get the rec_latin_lite_train.yml.
|
||||
|
||||
2. Manually modify the configuration file
|
||||
|
||||
|
@ -375,7 +376,9 @@ Currently, the multi-language algorithms supported by PaddleOCR are:
|
|||
|
||||
For more supported languages, please refer to : [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations)
|
||||
|
||||
The multi-language model training method is the same as the Chinese model. The training data set is 100w synthetic data. A small amount of fonts and test data can be downloaded on [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA),Extraction code:frgi.
|
||||
The multi-language model training method is the same as the Chinese model. The training data set is 100w synthetic data. A small amount of fonts and test data can be downloaded using the following two methods.
|
||||
* [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA),Extraction code:frgi.
|
||||
* [Google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
|
||||
|
||||
If you want to finetune on the basis of the existing model effect, please refer to the following instructions to modify the configuration file:
|
||||
|
||||
|
@ -394,7 +397,7 @@ Global:
|
|||
|
||||
Train:
|
||||
dataset:
|
||||
# Type of dataset,we support LMDBDateSet and SimpleDataSet
|
||||
# Type of dataset,we support LMDBDataSet and SimpleDataSet
|
||||
name: SimpleDataSet
|
||||
# Path of dataset
|
||||
data_dir: ./train_data/
|
||||
|
@ -404,7 +407,7 @@ Train:
|
|||
|
||||
Eval:
|
||||
dataset:
|
||||
# Type of dataset,we support LMDBDateSet and SimpleDataSet
|
||||
# Type of dataset,we support LMDBDataSet and SimpleDataSet
|
||||
name: SimpleDataSet
|
||||
# Path of dataset
|
||||
data_dir: ./train_data
|
||||
|
|
|
@ -15,8 +15,6 @@
|
|||
- 2020.6.8 Add [datasets](./datasets_en.md) and keep updating
|
||||
- 2020.6.5 Support exporting `attention` model to `inference_model`
|
||||
- 2020.6.5 Support separate prediction and recognition, output result score
|
||||
- 2020.6.5 Support exporting `attention` model to `inference_model`
|
||||
- 2020.6.5 Support separate prediction and recognition, output result score
|
||||
- 2020.5.30 Provide Lightweight Chinese OCR online experience
|
||||
- 2020.5.30 Model prediction and training support on Windows system
|
||||
- 2020.5.30 Open source general Chinese OCR model
|
||||
|
|
|
@ -59,7 +59,7 @@ Visualization of results
|
|||
from paddleocr import PaddleOCR,draw_ocr
|
||||
ocr = PaddleOCR(lang='en') # need to run only once to download and load model into memory
|
||||
img_path = 'PaddleOCR/doc/imgs_en/img_12.jpg'
|
||||
result = ocr.ocr(img_path)
|
||||
result = ocr.ocr(img_path, cls=False)
|
||||
for line in result:
|
||||
print(line)
|
||||
|
||||
|
@ -305,7 +305,8 @@ paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-f
|
|||
Support numpy array as input only when used by code
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
import cv2
|
||||
from paddleocr import PaddleOCR, draw_ocr, download_with_progressbar
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||
img_path = 'PaddleOCR/doc/imgs/11.jpg'
|
||||
img = cv2.imread(img_path)
|
||||
|
@ -316,7 +317,9 @@ for line in result:
|
|||
|
||||
# show result
|
||||
from PIL import Image
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
|
||||
download_with_progressbar(img_path, 'tmp.jpg')
|
||||
image = Image.open('tmp.jpg').convert('RGB')
|
||||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
|
@ -362,3 +365,5 @@ im_show.save('result.jpg')
|
|||
| det | Enable detction when `ppocr.ocr` func exec | TRUE |
|
||||
| rec | Enable recognition when `ppocr.ocr` func exec | TRUE |
|
||||
| cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE |
|
||||
| show_log | Whether to print log in det and rec | FALSE |
|
||||
| type | Perform ocr or table structuring, the value is selected in ['ocr','structure'] | ocr |
|
BIN
doc/joinus.PNG
Before Width: | Height: | Size: 100 KiB After Width: | Height: | Size: 212 KiB |
After Width: | Height: | Size: 263 KiB |
After Width: | Height: | Size: 672 KiB |
After Width: | Height: | Size: 672 KiB |
After Width: | Height: | Size: 1.5 MiB |
After Width: | Height: | Size: 1.4 MiB |
After Width: | Height: | Size: 2.5 MiB |
After Width: | Height: | Size: 521 KiB |
After Width: | Height: | Size: 146 KiB |
After Width: | Height: | Size: 24 KiB |
After Width: | Height: | Size: 552 KiB |
After Width: | Height: | Size: 416 KiB |
420
paddleocr.py
|
@ -19,254 +19,174 @@ __dir__ = os.path.dirname(__file__)
|
|||
sys.path.append(os.path.join(__dir__, ''))
|
||||
|
||||
import cv2
|
||||
import logging
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import tarfile
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from tools.infer import predict_system
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
|
||||
from tools.infer.utility import draw_ocr
|
||||
from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
|
||||
from tools.infer.utility import draw_ocr, str2bool
|
||||
from ppstructure.utility import init_args, draw_structure_result
|
||||
from ppstructure.predict_system import OCRSystem, save_structure_res
|
||||
|
||||
__all__ = ['PaddleOCR']
|
||||
__all__ = ['PaddleOCR', 'PPStructure', 'draw_ocr', 'draw_structure_result', 'save_structure_res','download_with_progressbar']
|
||||
|
||||
model_urls = {
|
||||
'det': {
|
||||
'ch':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
|
||||
'en':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar',
|
||||
'structure': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar'
|
||||
},
|
||||
'rec': {
|
||||
'ch': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
|
||||
},
|
||||
'en': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/en_dict.txt'
|
||||
},
|
||||
'french': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/french_dict.txt'
|
||||
},
|
||||
'german': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/german_dict.txt'
|
||||
},
|
||||
'korean': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/korean_dict.txt'
|
||||
},
|
||||
'japan': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/japan_dict.txt'
|
||||
},
|
||||
'chinese_cht': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
|
||||
},
|
||||
'ta': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/ta_dict.txt'
|
||||
},
|
||||
'te': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/te_dict.txt'
|
||||
},
|
||||
'ka': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/ka_dict.txt'
|
||||
},
|
||||
'latin': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/latin_dict.txt'
|
||||
},
|
||||
'arabic': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/arabic_dict.txt'
|
||||
},
|
||||
'cyrillic': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
|
||||
},
|
||||
'devanagari': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
|
||||
},
|
||||
'structure': {
|
||||
'url': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
|
||||
'dict_path': 'ppocr/utils/dict/table_dict.txt'
|
||||
}
|
||||
},
|
||||
'cls':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
|
||||
'cls': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
|
||||
'table': {
|
||||
'url': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar',
|
||||
'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
|
||||
}
|
||||
}
|
||||
|
||||
SUPPORT_DET_MODEL = ['DB']
|
||||
VERSION = '2.1'
|
||||
VERSION = '2.2'
|
||||
SUPPORT_REC_MODEL = ['CRNN']
|
||||
BASE_DIR = os.path.expanduser("~/.paddleocr/")
|
||||
|
||||
|
||||
def download_with_progressbar(url, save_path):
|
||||
response = requests.get(url, stream=True)
|
||||
total_size_in_bytes = int(response.headers.get('content-length', 0))
|
||||
block_size = 1024 # 1 Kibibyte
|
||||
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
|
||||
with open(save_path, 'wb') as file:
|
||||
for data in response.iter_content(block_size):
|
||||
progress_bar.update(len(data))
|
||||
file.write(data)
|
||||
progress_bar.close()
|
||||
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
|
||||
logger.error("Something went wrong while downloading models")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def maybe_download(model_storage_directory, url):
|
||||
# using custom model
|
||||
tar_file_name_list = [
|
||||
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
|
||||
]
|
||||
if not os.path.exists(
|
||||
os.path.join(model_storage_directory, 'inference.pdiparams')
|
||||
) or not os.path.exists(
|
||||
os.path.join(model_storage_directory, 'inference.pdmodel')):
|
||||
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
|
||||
print('download {} to {}'.format(url, tmp_path))
|
||||
os.makedirs(model_storage_directory, exist_ok=True)
|
||||
download_with_progressbar(url, tmp_path)
|
||||
with tarfile.open(tmp_path, 'r') as tarObj:
|
||||
for member in tarObj.getmembers():
|
||||
filename = None
|
||||
for tar_file_name in tar_file_name_list:
|
||||
if tar_file_name in member.name:
|
||||
filename = tar_file_name
|
||||
if filename is None:
|
||||
continue
|
||||
file = tarObj.extractfile(member)
|
||||
with open(
|
||||
os.path.join(model_storage_directory, filename),
|
||||
'wb') as f:
|
||||
f.write(file.read())
|
||||
os.remove(tmp_path)
|
||||
|
||||
|
||||
def parse_args(mMain=True, add_help=True):
|
||||
def parse_args(mMain=True):
|
||||
import argparse
|
||||
parser = init_args()
|
||||
parser.add_help = mMain
|
||||
parser.add_argument("--lang", type=str, default='ch')
|
||||
parser.add_argument("--det", type=str2bool, default=True)
|
||||
parser.add_argument("--rec", type=str2bool, default=True)
|
||||
parser.add_argument("--type", type=str, default='ocr')
|
||||
|
||||
def str2bool(v):
|
||||
return v.lower() in ("true", "t", "1")
|
||||
|
||||
for action in parser._actions:
|
||||
if action.dest in ['rec_char_dict_path', 'table_char_dict_path']:
|
||||
action.default = None
|
||||
if mMain:
|
||||
parser = argparse.ArgumentParser(add_help=add_help)
|
||||
# params for prediction engine
|
||||
parser.add_argument("--use_gpu", type=str2bool, default=True)
|
||||
parser.add_argument("--ir_optim", type=str2bool, default=True)
|
||||
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
|
||||
parser.add_argument("--gpu_mem", type=int, default=8000)
|
||||
|
||||
# params for text detector
|
||||
parser.add_argument("--image_dir", type=str)
|
||||
parser.add_argument("--det_algorithm", type=str, default='DB')
|
||||
parser.add_argument("--det_model_dir", type=str, default=None)
|
||||
parser.add_argument("--det_limit_side_len", type=float, default=960)
|
||||
parser.add_argument("--det_limit_type", type=str, default='max')
|
||||
|
||||
# DB parmas
|
||||
parser.add_argument("--det_db_thresh", type=float, default=0.3)
|
||||
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
|
||||
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
|
||||
parser.add_argument("--use_dilation", type=bool, default=False)
|
||||
parser.add_argument("--det_db_score_mode", type=str, default="fast")
|
||||
|
||||
# EAST parmas
|
||||
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
|
||||
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
|
||||
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
|
||||
|
||||
# params for text recognizer
|
||||
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
|
||||
parser.add_argument("--rec_model_dir", type=str, default=None)
|
||||
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
|
||||
parser.add_argument("--rec_char_type", type=str, default='ch')
|
||||
parser.add_argument("--rec_batch_num", type=int, default=6)
|
||||
parser.add_argument("--max_text_length", type=int, default=25)
|
||||
parser.add_argument("--rec_char_dict_path", type=str, default=None)
|
||||
parser.add_argument("--use_space_char", type=bool, default=True)
|
||||
parser.add_argument("--drop_score", type=float, default=0.5)
|
||||
|
||||
# params for text classifier
|
||||
parser.add_argument("--cls_model_dir", type=str, default=None)
|
||||
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
|
||||
parser.add_argument("--label_list", type=list, default=['0', '180'])
|
||||
parser.add_argument("--cls_batch_num", type=int, default=6)
|
||||
parser.add_argument("--cls_thresh", type=float, default=0.9)
|
||||
|
||||
parser.add_argument("--enable_mkldnn", type=bool, default=False)
|
||||
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
|
||||
parser.add_argument("--use_pdserving", type=str2bool, default=False)
|
||||
|
||||
parser.add_argument("--lang", type=str, default='ch')
|
||||
parser.add_argument("--det", type=str2bool, default=True)
|
||||
parser.add_argument("--rec", type=str2bool, default=True)
|
||||
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
|
||||
return parser.parse_args()
|
||||
else:
|
||||
return argparse.Namespace(
|
||||
use_gpu=True,
|
||||
ir_optim=True,
|
||||
use_tensorrt=False,
|
||||
gpu_mem=8000,
|
||||
image_dir='',
|
||||
det_algorithm='DB',
|
||||
det_model_dir=None,
|
||||
det_limit_side_len=960,
|
||||
det_limit_type='max',
|
||||
det_db_thresh=0.3,
|
||||
det_db_box_thresh=0.5,
|
||||
det_db_unclip_ratio=1.6,
|
||||
use_dilation=False,
|
||||
det_db_score_mode="fast",
|
||||
det_east_score_thresh=0.8,
|
||||
det_east_cover_thresh=0.1,
|
||||
det_east_nms_thresh=0.2,
|
||||
rec_algorithm='CRNN',
|
||||
rec_model_dir=None,
|
||||
rec_image_shape="3, 32, 320",
|
||||
rec_char_type='ch',
|
||||
rec_batch_num=6,
|
||||
max_text_length=25,
|
||||
rec_char_dict_path=None,
|
||||
use_space_char=True,
|
||||
drop_score=0.5,
|
||||
cls_model_dir=None,
|
||||
cls_image_shape="3, 48, 192",
|
||||
label_list=['0', '180'],
|
||||
cls_batch_num=6,
|
||||
cls_thresh=0.9,
|
||||
enable_mkldnn=False,
|
||||
use_zero_copy_run=False,
|
||||
use_pdserving=False,
|
||||
lang='ch',
|
||||
det=True,
|
||||
rec=True,
|
||||
use_angle_cls=False)
|
||||
inference_args_dict = {}
|
||||
for action in parser._actions:
|
||||
inference_args_dict[action.dest] = action.default
|
||||
return argparse.Namespace(**inference_args_dict)
|
||||
|
||||
|
||||
def parse_lang(lang):
|
||||
latin_lang = [
|
||||
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
|
||||
'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
|
||||
'mt', 'nl', 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk',
|
||||
'sl', 'sq', 'sv', 'sw', 'tl', 'tr', 'uz', 'vi'
|
||||
]
|
||||
arabic_lang = ['ar', 'fa', 'ug', 'ur']
|
||||
cyrillic_lang = [
|
||||
'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd',
|
||||
'ava', 'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
|
||||
]
|
||||
devanagari_lang = [
|
||||
'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new',
|
||||
'gom', 'sa', 'bgc'
|
||||
]
|
||||
if lang in latin_lang:
|
||||
lang = "latin"
|
||||
elif lang in arabic_lang:
|
||||
lang = "arabic"
|
||||
elif lang in cyrillic_lang:
|
||||
lang = "cyrillic"
|
||||
elif lang in devanagari_lang:
|
||||
lang = "devanagari"
|
||||
assert lang in model_urls[
|
||||
'rec'], 'param lang must in {}, but got {}'.format(
|
||||
model_urls['rec'].keys(), lang)
|
||||
if lang == "ch":
|
||||
det_lang = "ch"
|
||||
elif lang == 'structure':
|
||||
det_lang = 'structure'
|
||||
else:
|
||||
det_lang = "en"
|
||||
return lang, det_lang
|
||||
|
||||
|
||||
class PaddleOCR(predict_system.TextSystem):
|
||||
|
@ -276,77 +196,43 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
args:
|
||||
**kwargs: other params show in paddleocr --help
|
||||
"""
|
||||
postprocess_params = parse_args(mMain=False, add_help=False)
|
||||
postprocess_params.__dict__.update(**kwargs)
|
||||
self.use_angle_cls = postprocess_params.use_angle_cls
|
||||
lang = postprocess_params.lang
|
||||
latin_lang = [
|
||||
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
|
||||
'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
|
||||
'mt', 'nl', 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk',
|
||||
'sl', 'sq', 'sv', 'sw', 'tl', 'tr', 'uz', 'vi'
|
||||
]
|
||||
arabic_lang = ['ar', 'fa', 'ug', 'ur']
|
||||
cyrillic_lang = [
|
||||
'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd',
|
||||
'ava', 'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
|
||||
]
|
||||
devanagari_lang = [
|
||||
'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new',
|
||||
'gom', 'sa', 'bgc'
|
||||
]
|
||||
if lang in latin_lang:
|
||||
lang = "latin"
|
||||
elif lang in arabic_lang:
|
||||
lang = "arabic"
|
||||
elif lang in cyrillic_lang:
|
||||
lang = "cyrillic"
|
||||
elif lang in devanagari_lang:
|
||||
lang = "devanagari"
|
||||
assert lang in model_urls[
|
||||
'rec'], 'param lang must in {}, but got {}'.format(
|
||||
model_urls['rec'].keys(), lang)
|
||||
if lang == "ch":
|
||||
det_lang = "ch"
|
||||
else:
|
||||
det_lang = "en"
|
||||
use_inner_dict = False
|
||||
if postprocess_params.rec_char_dict_path is None:
|
||||
use_inner_dict = True
|
||||
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
|
||||
'dict_path']
|
||||
params = parse_args(mMain=False)
|
||||
params.__dict__.update(**kwargs)
|
||||
if not params.show_log:
|
||||
logger.setLevel(logging.INFO)
|
||||
self.use_angle_cls = params.use_angle_cls
|
||||
lang, det_lang = parse_lang(params.lang)
|
||||
|
||||
# init model dir
|
||||
if postprocess_params.det_model_dir is None:
|
||||
postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION,
|
||||
'det', det_lang)
|
||||
if postprocess_params.rec_model_dir is None:
|
||||
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
|
||||
'rec', lang)
|
||||
if postprocess_params.cls_model_dir is None:
|
||||
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
|
||||
print(postprocess_params)
|
||||
params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
|
||||
model_urls['det'][det_lang])
|
||||
params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
|
||||
model_urls['rec'][lang]['url'])
|
||||
params.cls_model_dir, cls_url = confirm_model_dir_url(params.cls_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'ocr', 'cls'),
|
||||
model_urls['cls'])
|
||||
# download model
|
||||
maybe_download(postprocess_params.det_model_dir,
|
||||
model_urls['det'][det_lang])
|
||||
maybe_download(postprocess_params.rec_model_dir,
|
||||
model_urls['rec'][lang]['url'])
|
||||
maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
|
||||
maybe_download(params.det_model_dir, det_url)
|
||||
maybe_download(params.rec_model_dir, rec_url)
|
||||
maybe_download(params.cls_model_dir, cls_url)
|
||||
|
||||
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
|
||||
if params.det_algorithm not in SUPPORT_DET_MODEL:
|
||||
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
|
||||
sys.exit(0)
|
||||
if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
|
||||
if params.rec_algorithm not in SUPPORT_REC_MODEL:
|
||||
logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
|
||||
sys.exit(0)
|
||||
if use_inner_dict:
|
||||
postprocess_params.rec_char_dict_path = str(
|
||||
Path(__file__).parent / postprocess_params.rec_char_dict_path)
|
||||
|
||||
if params.rec_char_dict_path is None:
|
||||
params.rec_char_dict_path = str(Path(__file__).parent / model_urls['rec'][lang]['dict_path'])
|
||||
|
||||
print(params)
|
||||
# init det_model and rec_model
|
||||
super().__init__(postprocess_params)
|
||||
super().__init__(params)
|
||||
|
||||
def ocr(self, img, det=True, rec=True, cls=False):
|
||||
def ocr(self, img, det=True, rec=True, cls=True):
|
||||
"""
|
||||
ocr with paddleocr
|
||||
args:
|
||||
|
@ -358,9 +244,7 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
if isinstance(img, list) and det == True:
|
||||
logger.error('When input a list of images, det must be false')
|
||||
exit(0)
|
||||
if cls == False:
|
||||
self.use_angle_cls = False
|
||||
elif cls == True and self.use_angle_cls == False:
|
||||
if cls == True and self.use_angle_cls == False:
|
||||
logger.warning(
|
||||
'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process'
|
||||
)
|
||||
|
@ -382,7 +266,7 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
if isinstance(img, np.ndarray) and len(img.shape) == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
if det and rec:
|
||||
dt_boxes, rec_res = self.__call__(img)
|
||||
dt_boxes, rec_res = self.__call__(img, cls)
|
||||
return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
|
||||
elif det and not rec:
|
||||
dt_boxes, elapse = self.text_detector(img)
|
||||
|
@ -392,7 +276,7 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
else:
|
||||
if not isinstance(img, list):
|
||||
img = [img]
|
||||
if self.use_angle_cls:
|
||||
if self.use_angle_cls and cls:
|
||||
img, cls_res, elapse = self.text_classifier(img)
|
||||
if not rec:
|
||||
return cls_res
|
||||
|
@ -400,11 +284,64 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
return rec_res
|
||||
|
||||
|
||||
class PPStructure(OCRSystem):
|
||||
def __init__(self, **kwargs):
|
||||
params = parse_args(mMain=False)
|
||||
params.__dict__.update(**kwargs)
|
||||
if not params.show_log:
|
||||
logger.setLevel(logging.INFO)
|
||||
lang, det_lang = parse_lang(params.lang)
|
||||
|
||||
# init model dir
|
||||
params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
|
||||
model_urls['det'][det_lang])
|
||||
params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
|
||||
model_urls['rec'][lang]['url'])
|
||||
params.table_model_dir, table_url = confirm_model_dir_url(params.table_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'ocr', 'table'),
|
||||
model_urls['table']['url'])
|
||||
# download model
|
||||
maybe_download(params.det_model_dir, det_url)
|
||||
maybe_download(params.rec_model_dir, rec_url)
|
||||
maybe_download(params.table_model_dir, table_url)
|
||||
|
||||
if params.rec_char_dict_path is None:
|
||||
params.rec_char_dict_path = str(Path(__file__).parent / model_urls['rec'][lang]['dict_path'])
|
||||
if params.table_char_dict_path is None:
|
||||
params.table_char_dict_path = str(Path(__file__).parent / model_urls['table']['dict_path'])
|
||||
|
||||
print(params)
|
||||
super().__init__(params)
|
||||
|
||||
def __call__(self, img):
|
||||
if isinstance(img, str):
|
||||
# download net image
|
||||
if img.startswith('http'):
|
||||
download_with_progressbar(img, 'tmp.jpg')
|
||||
img = 'tmp.jpg'
|
||||
image_file = img
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
if not flag:
|
||||
with open(image_file, 'rb') as f:
|
||||
np_arr = np.frombuffer(f.read(), dtype=np.uint8)
|
||||
img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
logger.error("error in loading image:{}".format(image_file))
|
||||
return None
|
||||
if isinstance(img, np.ndarray) and len(img.shape) == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
|
||||
res = super().__call__(img)
|
||||
return res
|
||||
|
||||
|
||||
def main():
|
||||
# for cmd
|
||||
args = parse_args(mMain=True)
|
||||
image_dir = args.image_dir
|
||||
if image_dir.startswith('http'):
|
||||
if is_link(image_dir):
|
||||
download_with_progressbar(image_dir, 'tmp.jpg')
|
||||
image_file_list = ['tmp.jpg']
|
||||
else:
|
||||
|
@ -412,14 +349,29 @@ def main():
|
|||
if len(image_file_list) == 0:
|
||||
logger.error('no images find in {}'.format(args.image_dir))
|
||||
return
|
||||
if args.type == 'ocr':
|
||||
engine = PaddleOCR(**(args.__dict__))
|
||||
elif args.type == 'structure':
|
||||
engine = PPStructure(**(args.__dict__))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
ocr_engine = PaddleOCR(**(args.__dict__))
|
||||
for img_path in image_file_list:
|
||||
img_name = os.path.basename(img_path).split('.')[0]
|
||||
logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
|
||||
result = ocr_engine.ocr(img_path,
|
||||
if args.type == 'ocr':
|
||||
result = engine.ocr(img_path,
|
||||
det=args.det,
|
||||
rec=args.rec,
|
||||
cls=args.use_angle_cls)
|
||||
if result is not None:
|
||||
for line in result:
|
||||
logger.info(line)
|
||||
if result is not None:
|
||||
for line in result:
|
||||
logger.info(line)
|
||||
elif args.type == 'structure':
|
||||
result = engine(img_path)
|
||||
save_structure_res(result, args.output, img_name)
|
||||
|
||||
for item in result:
|
||||
item.pop('img')
|
||||
logger.info(item)
|
||||
|
||||
|
|
|
@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators
|
|||
from ppocr.data.simple_dataset import SimpleDataSet
|
||||
from ppocr.data.lmdb_dataset import LMDBDataSet
|
||||
from ppocr.data.pgnet_dataset import PGDataSet
|
||||
from ppocr.data.pubtab_dataset import PubTabDataSet
|
||||
|
||||
__all__ = ['build_dataloader', 'transform', 'create_operators']
|
||||
|
||||
|
@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp)
|
|||
def build_dataloader(config, mode, device, logger, seed=None):
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet']
|
||||
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet']
|
||||
module_name = config[mode]['dataset']['name']
|
||||
assert module_name in support_dict, Exception(
|
||||
'DataSet only support {}'.format(support_dict))
|
||||
|
|
|
@ -23,12 +23,14 @@ from .random_crop_data import EastRandomCropData, PSERandomCrop
|
|||
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
|
||||
from .randaugment import RandAugment
|
||||
from .copy_paste import CopyPaste
|
||||
from .operators import *
|
||||
from .label_ops import *
|
||||
|
||||
from .east_process import *
|
||||
from .sast_process import *
|
||||
from .pg_process import *
|
||||
from .gen_table_mask import *
|
||||
|
||||
|
||||
def transform(data, ops=None):
|
||||
|
|
|
@ -0,0 +1,166 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import cv2
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from shapely.geometry import Polygon
|
||||
|
||||
from ppocr.data.imaug.iaa_augment import IaaAugment
|
||||
from ppocr.data.imaug.random_crop_data import is_poly_outside_rect
|
||||
from tools.infer.utility import get_rotate_crop_image
|
||||
|
||||
|
||||
class CopyPaste(object):
|
||||
def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
|
||||
self.ext_data_num = 1
|
||||
self.objects_paste_ratio = objects_paste_ratio
|
||||
self.limit_paste = limit_paste
|
||||
augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}]
|
||||
self.aug = IaaAugment(augmenter_args)
|
||||
|
||||
def __call__(self, data):
|
||||
src_img = data['image']
|
||||
src_polys = data['polys'].tolist()
|
||||
src_ignores = data['ignore_tags'].tolist()
|
||||
ext_data = data['ext_data'][0]
|
||||
ext_image = ext_data['image']
|
||||
ext_polys = ext_data['polys']
|
||||
ext_ignores = ext_data['ignore_tags']
|
||||
|
||||
indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
|
||||
select_num = max(
|
||||
1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))
|
||||
|
||||
random.shuffle(indexs)
|
||||
select_idxs = indexs[:select_num]
|
||||
select_polys = ext_polys[select_idxs]
|
||||
select_ignores = ext_ignores[select_idxs]
|
||||
|
||||
src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
|
||||
ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
|
||||
src_img = Image.fromarray(src_img).convert('RGBA')
|
||||
for poly, tag in zip(select_polys, select_ignores):
|
||||
box_img = get_rotate_crop_image(ext_image, poly)
|
||||
|
||||
src_img, box = self.paste_img(src_img, box_img, src_polys)
|
||||
if box is not None:
|
||||
src_polys.append(box)
|
||||
src_ignores.append(tag)
|
||||
src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
|
||||
h, w = src_img.shape[:2]
|
||||
src_polys = np.array(src_polys)
|
||||
src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
|
||||
src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
|
||||
data['image'] = src_img
|
||||
data['polys'] = src_polys
|
||||
data['ignore_tags'] = np.array(src_ignores)
|
||||
return data
|
||||
|
||||
def paste_img(self, src_img, box_img, src_polys):
|
||||
box_img_pil = Image.fromarray(box_img).convert('RGBA')
|
||||
src_w, src_h = src_img.size
|
||||
box_w, box_h = box_img_pil.size
|
||||
|
||||
angle = np.random.randint(0, 360)
|
||||
box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
|
||||
box = rotate_bbox(box_img, box, angle)[0]
|
||||
box_img_pil = box_img_pil.rotate(angle, expand=1)
|
||||
box_w, box_h = box_img_pil.width, box_img_pil.height
|
||||
if src_w - box_w < 0 or src_h - box_h < 0:
|
||||
return src_img, None
|
||||
|
||||
paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w,
|
||||
src_h - box_h)
|
||||
if paste_x is None:
|
||||
return src_img, None
|
||||
box[:, 0] += paste_x
|
||||
box[:, 1] += paste_y
|
||||
r, g, b, A = box_img_pil.split()
|
||||
src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)
|
||||
|
||||
return src_img, box
|
||||
|
||||
def select_coord(self, src_polys, box, endx, endy):
|
||||
if self.limit_paste:
|
||||
xmin, ymin, xmax, ymax = box[:, 0].min(), box[:, 1].min(
|
||||
), box[:, 0].max(), box[:, 1].max()
|
||||
for _ in range(50):
|
||||
paste_x = random.randint(0, endx)
|
||||
paste_y = random.randint(0, endy)
|
||||
xmin1 = xmin + paste_x
|
||||
xmax1 = xmax + paste_x
|
||||
ymin1 = ymin + paste_y
|
||||
ymax1 = ymax + paste_y
|
||||
|
||||
num_poly_in_rect = 0
|
||||
for poly in src_polys:
|
||||
if not is_poly_outside_rect(poly, xmin1, ymin1,
|
||||
xmax1 - xmin1, ymax1 - ymin1):
|
||||
num_poly_in_rect += 1
|
||||
break
|
||||
if num_poly_in_rect == 0:
|
||||
return paste_x, paste_y
|
||||
return None, None
|
||||
else:
|
||||
paste_x = random.randint(0, endx)
|
||||
paste_y = random.randint(0, endy)
|
||||
return paste_x, paste_y
|
||||
|
||||
|
||||
def get_union(pD, pG):
|
||||
return Polygon(pD).union(Polygon(pG)).area
|
||||
|
||||
|
||||
def get_intersection_over_union(pD, pG):
|
||||
return get_intersection(pD, pG) / get_union(pD, pG)
|
||||
|
||||
|
||||
def get_intersection(pD, pG):
|
||||
return Polygon(pD).intersection(Polygon(pG)).area
|
||||
|
||||
|
||||
def rotate_bbox(img, text_polys, angle, scale=1):
|
||||
"""
|
||||
from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
|
||||
Args:
|
||||
img: np.ndarray
|
||||
text_polys: np.ndarray N*4*2
|
||||
angle: int
|
||||
scale: int
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
w = img.shape[1]
|
||||
h = img.shape[0]
|
||||
|
||||
rangle = np.deg2rad(angle)
|
||||
nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
|
||||
nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
|
||||
rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
|
||||
rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
|
||||
rot_mat[0, 2] += rot_move[0]
|
||||
rot_mat[1, 2] += rot_move[1]
|
||||
|
||||
# ---------------------- rotate box ----------------------
|
||||
rot_text_polys = list()
|
||||
for bbox in text_polys:
|
||||
point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
|
||||
point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
|
||||
point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
|
||||
point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
|
||||
rot_text_polys.append([point1, point2, point3, point4])
|
||||
return np.array(rot_text_polys, dtype=np.float32)
|
|
@ -0,0 +1,244 @@
|
|||
"""
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import sys
|
||||
import six
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
class GenTableMask(object):
|
||||
""" gen table mask """
|
||||
|
||||
def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
|
||||
self.shrink_h_max = 5
|
||||
self.shrink_w_max = 5
|
||||
self.mask_type = mask_type
|
||||
|
||||
def projection(self, erosion, h, w, spilt_threshold=0):
|
||||
# 水平投影
|
||||
projection_map = np.ones_like(erosion)
|
||||
project_val_array = [0 for _ in range(0, h)]
|
||||
|
||||
for j in range(0, h):
|
||||
for i in range(0, w):
|
||||
if erosion[j, i] == 255:
|
||||
project_val_array[j] += 1
|
||||
# 根据数组,获取切割点
|
||||
start_idx = 0 # 记录进入字符区的索引
|
||||
end_idx = 0 # 记录进入空白区域的索引
|
||||
in_text = False # 是否遍历到了字符区内
|
||||
box_list = []
|
||||
for i in range(len(project_val_array)):
|
||||
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
|
||||
in_text = True
|
||||
start_idx = i
|
||||
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
|
||||
end_idx = i
|
||||
in_text = False
|
||||
if end_idx - start_idx <= 2:
|
||||
continue
|
||||
box_list.append((start_idx, end_idx + 1))
|
||||
|
||||
if in_text:
|
||||
box_list.append((start_idx, h - 1))
|
||||
# 绘制投影直方图
|
||||
for j in range(0, h):
|
||||
for i in range(0, project_val_array[j]):
|
||||
projection_map[j, i] = 0
|
||||
return box_list, projection_map
|
||||
|
||||
def projection_cx(self, box_img):
|
||||
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
|
||||
h, w = box_gray_img.shape
|
||||
# 灰度图片进行二值化处理
|
||||
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
|
||||
# 纵向腐蚀
|
||||
if h < w:
|
||||
kernel = np.ones((2, 1), np.uint8)
|
||||
erode = cv2.erode(thresh1, kernel, iterations=1)
|
||||
else:
|
||||
erode = thresh1
|
||||
# 水平膨胀
|
||||
kernel = np.ones((1, 5), np.uint8)
|
||||
erosion = cv2.dilate(erode, kernel, iterations=1)
|
||||
# 水平投影
|
||||
projection_map = np.ones_like(erosion)
|
||||
project_val_array = [0 for _ in range(0, h)]
|
||||
|
||||
for j in range(0, h):
|
||||
for i in range(0, w):
|
||||
if erosion[j, i] == 255:
|
||||
project_val_array[j] += 1
|
||||
# 根据数组,获取切割点
|
||||
start_idx = 0 # 记录进入字符区的索引
|
||||
end_idx = 0 # 记录进入空白区域的索引
|
||||
in_text = False # 是否遍历到了字符区内
|
||||
box_list = []
|
||||
spilt_threshold = 0
|
||||
for i in range(len(project_val_array)):
|
||||
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
|
||||
in_text = True
|
||||
start_idx = i
|
||||
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
|
||||
end_idx = i
|
||||
in_text = False
|
||||
if end_idx - start_idx <= 2:
|
||||
continue
|
||||
box_list.append((start_idx, end_idx + 1))
|
||||
|
||||
if in_text:
|
||||
box_list.append((start_idx, h - 1))
|
||||
# 绘制投影直方图
|
||||
for j in range(0, h):
|
||||
for i in range(0, project_val_array[j]):
|
||||
projection_map[j, i] = 0
|
||||
split_bbox_list = []
|
||||
if len(box_list) > 1:
|
||||
for i, (h_start, h_end) in enumerate(box_list):
|
||||
if i == 0:
|
||||
h_start = 0
|
||||
if i == len(box_list):
|
||||
h_end = h
|
||||
word_img = erosion[h_start:h_end + 1, :]
|
||||
word_h, word_w = word_img.shape
|
||||
w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h)
|
||||
w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
|
||||
if h_start > 0:
|
||||
h_start -= 1
|
||||
h_end += 1
|
||||
word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
|
||||
split_bbox_list.append([w_start, h_start, w_end, h_end])
|
||||
else:
|
||||
split_bbox_list.append([0, 0, w, h])
|
||||
return split_bbox_list
|
||||
|
||||
def shrink_bbox(self, bbox):
|
||||
left, top, right, bottom = bbox
|
||||
sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
|
||||
sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
|
||||
left_new = left + sh_w
|
||||
right_new = right - sh_w
|
||||
top_new = top + sh_h
|
||||
bottom_new = bottom - sh_h
|
||||
if left_new >= right_new:
|
||||
left_new = left
|
||||
right_new = right
|
||||
if top_new >= bottom_new:
|
||||
top_new = top
|
||||
bottom_new = bottom
|
||||
return [left_new, top_new, right_new, bottom_new]
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
cells = data['cells']
|
||||
height, width = img.shape[0:2]
|
||||
if self.mask_type == 1:
|
||||
mask_img = np.zeros((height, width), dtype=np.float32)
|
||||
else:
|
||||
mask_img = np.zeros((height, width, 3), dtype=np.float32)
|
||||
cell_num = len(cells)
|
||||
for cno in range(cell_num):
|
||||
if "bbox" in cells[cno]:
|
||||
bbox = cells[cno]['bbox']
|
||||
left, top, right, bottom = bbox
|
||||
box_img = img[top:bottom, left:right, :].copy()
|
||||
split_bbox_list = self.projection_cx(box_img)
|
||||
for sno in range(len(split_bbox_list)):
|
||||
split_bbox_list[sno][0] += left
|
||||
split_bbox_list[sno][1] += top
|
||||
split_bbox_list[sno][2] += left
|
||||
split_bbox_list[sno][3] += top
|
||||
|
||||
for sno in range(len(split_bbox_list)):
|
||||
left, top, right, bottom = split_bbox_list[sno]
|
||||
left, top, right, bottom = self.shrink_bbox([left, top, right, bottom])
|
||||
if self.mask_type == 1:
|
||||
mask_img[top:bottom, left:right] = 1.0
|
||||
data['mask_img'] = mask_img
|
||||
else:
|
||||
mask_img[top:bottom, left:right, :] = (255, 255, 255)
|
||||
data['image'] = mask_img
|
||||
return data
|
||||
|
||||
class ResizeTableImage(object):
|
||||
def __init__(self, max_len, **kwargs):
|
||||
super(ResizeTableImage, self).__init__()
|
||||
self.max_len = max_len
|
||||
|
||||
def get_img_bbox(self, cells):
|
||||
bbox_list = []
|
||||
if len(cells) == 0:
|
||||
return bbox_list
|
||||
cell_num = len(cells)
|
||||
for cno in range(cell_num):
|
||||
if "bbox" in cells[cno]:
|
||||
bbox = cells[cno]['bbox']
|
||||
bbox_list.append(bbox)
|
||||
return bbox_list
|
||||
|
||||
def resize_img_table(self, img, bbox_list, max_len):
|
||||
height, width = img.shape[0:2]
|
||||
ratio = max_len / (max(height, width) * 1.0)
|
||||
resize_h = int(height * ratio)
|
||||
resize_w = int(width * ratio)
|
||||
img_new = cv2.resize(img, (resize_w, resize_h))
|
||||
bbox_list_new = []
|
||||
for bno in range(len(bbox_list)):
|
||||
left, top, right, bottom = bbox_list[bno].copy()
|
||||
left = int(left * ratio)
|
||||
top = int(top * ratio)
|
||||
right = int(right * ratio)
|
||||
bottom = int(bottom * ratio)
|
||||
bbox_list_new.append([left, top, right, bottom])
|
||||
return img_new, bbox_list_new
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if 'cells' not in data:
|
||||
cells = []
|
||||
else:
|
||||
cells = data['cells']
|
||||
bbox_list = self.get_img_bbox(cells)
|
||||
img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
|
||||
data['image'] = img_new
|
||||
cell_num = len(cells)
|
||||
bno = 0
|
||||
for cno in range(cell_num):
|
||||
if "bbox" in data['cells'][cno]:
|
||||
data['cells'][cno]['bbox'] = bbox_list_new[bno]
|
||||
bno += 1
|
||||
data['max_len'] = self.max_len
|
||||
return data
|
||||
|
||||
class PaddingTableImage(object):
|
||||
def __init__(self, **kwargs):
|
||||
super(PaddingTableImage, self).__init__()
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
max_len = data['max_len']
|
||||
padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
|
||||
height, width = img.shape[0:2]
|
||||
padding_img[0:height, 0:width, :] = img.copy()
|
||||
data['image'] = padding_img
|
||||
return data
|
||||
|
|
@ -19,6 +19,7 @@ from __future__ import unicode_literals
|
|||
|
||||
import numpy as np
|
||||
import string
|
||||
import json
|
||||
|
||||
|
||||
class ClsLabelEncode(object):
|
||||
|
@ -39,7 +40,6 @@ class DetLabelEncode(object):
|
|||
pass
|
||||
|
||||
def __call__(self, data):
|
||||
import json
|
||||
label = data['label']
|
||||
label = json.loads(label)
|
||||
nBox = len(label)
|
||||
|
@ -53,6 +53,8 @@ class DetLabelEncode(object):
|
|||
txt_tags.append(True)
|
||||
else:
|
||||
txt_tags.append(False)
|
||||
if len(boxes) == 0:
|
||||
return None
|
||||
boxes = self.expand_points_num(boxes)
|
||||
boxes = np.array(boxes, dtype=np.float32)
|
||||
txt_tags = np.array(txt_tags, dtype=np.bool)
|
||||
|
@ -351,3 +353,171 @@ class SRNLabelEncode(BaseRecLabelEncode):
|
|||
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
|
||||
% beg_or_end
|
||||
return idx
|
||||
|
||||
|
||||
class TableLabelEncode(object):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
max_elem_length,
|
||||
max_cell_num,
|
||||
character_dict_path,
|
||||
span_weight=1.0,
|
||||
**kwargs):
|
||||
self.max_text_length = max_text_length
|
||||
self.max_elem_length = max_elem_length
|
||||
self.max_cell_num = max_cell_num
|
||||
list_character, list_elem = self.load_char_elem_dict(
|
||||
character_dict_path)
|
||||
list_character = self.add_special_char(list_character)
|
||||
list_elem = self.add_special_char(list_elem)
|
||||
self.dict_character = {}
|
||||
for i, char in enumerate(list_character):
|
||||
self.dict_character[char] = i
|
||||
self.dict_elem = {}
|
||||
for i, elem in enumerate(list_elem):
|
||||
self.dict_elem[elem] = i
|
||||
self.span_weight = span_weight
|
||||
|
||||
def load_char_elem_dict(self, character_dict_path):
|
||||
list_character = []
|
||||
list_elem = []
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
substr = lines[0].decode('utf-8').strip("\r\n").split("\t")
|
||||
character_num = int(substr[0])
|
||||
elem_num = int(substr[1])
|
||||
|
||||
for cno in range(1, 1 + character_num):
|
||||
character = lines[cno].decode('utf-8').strip("\r\n")
|
||||
list_character.append(character)
|
||||
for eno in range(1 + character_num, 1 + character_num + elem_num):
|
||||
elem = lines[eno].decode('utf-8').strip("\r\n")
|
||||
list_elem.append(elem)
|
||||
return list_character, list_elem
|
||||
|
||||
def add_special_char(self, list_character):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
list_character = [self.beg_str] + list_character + [self.end_str]
|
||||
return list_character
|
||||
|
||||
def get_span_idx_list(self):
|
||||
span_idx_list = []
|
||||
for elem in self.dict_elem:
|
||||
if 'span' in elem:
|
||||
span_idx_list.append(self.dict_elem[elem])
|
||||
return span_idx_list
|
||||
|
||||
def __call__(self, data):
|
||||
cells = data['cells']
|
||||
structure = data['structure']['tokens']
|
||||
structure = self.encode(structure, 'elem')
|
||||
if structure is None:
|
||||
return None
|
||||
elem_num = len(structure)
|
||||
structure = [0] + structure + [len(self.dict_elem) - 1]
|
||||
structure = structure + [0] * (self.max_elem_length + 2 - len(structure)
|
||||
)
|
||||
structure = np.array(structure)
|
||||
data['structure'] = structure
|
||||
elem_char_idx1 = self.dict_elem['<td>']
|
||||
elem_char_idx2 = self.dict_elem['<td']
|
||||
span_idx_list = self.get_span_idx_list()
|
||||
td_idx_list = np.logical_or(structure == elem_char_idx1,
|
||||
structure == elem_char_idx2)
|
||||
td_idx_list = np.where(td_idx_list)[0]
|
||||
|
||||
structure_mask = np.ones(
|
||||
(self.max_elem_length + 2, 1), dtype=np.float32)
|
||||
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
|
||||
bbox_list_mask = np.zeros(
|
||||
(self.max_elem_length + 2, 1), dtype=np.float32)
|
||||
img_height, img_width, img_ch = data['image'].shape
|
||||
if len(span_idx_list) > 0:
|
||||
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
|
||||
span_weight = min(max(span_weight, 1.0), self.span_weight)
|
||||
for cno in range(len(cells)):
|
||||
if 'bbox' in cells[cno]:
|
||||
bbox = cells[cno]['bbox'].copy()
|
||||
bbox[0] = bbox[0] * 1.0 / img_width
|
||||
bbox[1] = bbox[1] * 1.0 / img_height
|
||||
bbox[2] = bbox[2] * 1.0 / img_width
|
||||
bbox[3] = bbox[3] * 1.0 / img_height
|
||||
td_idx = td_idx_list[cno]
|
||||
bbox_list[td_idx] = bbox
|
||||
bbox_list_mask[td_idx] = 1.0
|
||||
cand_span_idx = td_idx + 1
|
||||
if cand_span_idx < (self.max_elem_length + 2):
|
||||
if structure[cand_span_idx] in span_idx_list:
|
||||
structure_mask[cand_span_idx] = span_weight
|
||||
|
||||
data['bbox_list'] = bbox_list
|
||||
data['bbox_list_mask'] = bbox_list_mask
|
||||
data['structure_mask'] = structure_mask
|
||||
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
|
||||
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
|
||||
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
|
||||
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
|
||||
data['sp_tokens'] = np.array([
|
||||
char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
|
||||
elem_char_idx1, elem_char_idx2, self.max_text_length,
|
||||
self.max_elem_length, self.max_cell_num, elem_num
|
||||
])
|
||||
return data
|
||||
|
||||
def encode(self, text, char_or_elem):
|
||||
"""convert text-label into text-index.
|
||||
"""
|
||||
if char_or_elem == "char":
|
||||
max_len = self.max_text_length
|
||||
current_dict = self.dict_character
|
||||
else:
|
||||
max_len = self.max_elem_length
|
||||
current_dict = self.dict_elem
|
||||
if len(text) > max_len:
|
||||
return None
|
||||
if len(text) == 0:
|
||||
if char_or_elem == "char":
|
||||
return [self.dict_character['space']]
|
||||
else:
|
||||
return None
|
||||
text_list = []
|
||||
for char in text:
|
||||
if char not in current_dict:
|
||||
return None
|
||||
text_list.append(current_dict[char])
|
||||
if len(text_list) == 0:
|
||||
if char_or_elem == "char":
|
||||
return [self.dict_character['space']]
|
||||
else:
|
||||
return None
|
||||
return text_list
|
||||
|
||||
def get_ignored_tokens(self, char_or_elem):
|
||||
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
|
||||
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
|
||||
return [beg_idx, end_idx]
|
||||
|
||||
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
|
||||
if char_or_elem == "char":
|
||||
if beg_or_end == "beg":
|
||||
idx = np.array(self.dict_character[self.beg_str])
|
||||
elif beg_or_end == "end":
|
||||
idx = np.array(self.dict_character[self.end_str])
|
||||
else:
|
||||
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
|
||||
% beg_or_end
|
||||
elif char_or_elem == "elem":
|
||||
if beg_or_end == "beg":
|
||||
idx = np.array(self.dict_elem[self.beg_str])
|
||||
elif beg_or_end == "end":
|
||||
idx = np.array(self.dict_elem[self.end_str])
|
||||
else:
|
||||
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
|
||||
% beg_or_end
|
||||
else:
|
||||
assert False, "Unsupport type %s in char_or_elem" \
|
||||
% char_or_elem
|
||||
return idx
|
||||
|
|
|
@ -81,7 +81,7 @@ class NormalizeImage(object):
|
|||
assert isinstance(img,
|
||||
np.ndarray), "invalid input 'img' in NormalizeImage"
|
||||
data['image'] = (
|
||||
img.astype('float32') * self.scale - self.mean) / self.std
|
||||
img.astype('float32') * self.scale - self.mean) / self.std
|
||||
return data
|
||||
|
||||
|
||||
|
@ -163,7 +163,7 @@ class DetResizeForTest(object):
|
|||
img, (ratio_h, ratio_w)
|
||||
"""
|
||||
limit_side_len = self.limit_side_len
|
||||
h, w, _ = img.shape
|
||||
h, w, c = img.shape
|
||||
|
||||
# limit the max side
|
||||
if self.limit_type == 'max':
|
||||
|
@ -174,7 +174,7 @@ class DetResizeForTest(object):
|
|||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.
|
||||
else:
|
||||
elif self.limit_type == 'min':
|
||||
if min(h, w) < limit_side_len:
|
||||
if h < w:
|
||||
ratio = float(limit_side_len) / h
|
||||
|
@ -182,6 +182,10 @@ class DetResizeForTest(object):
|
|||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.
|
||||
elif self.limit_type == 'resize_long':
|
||||
ratio = float(limit_side_len) / max(h,w)
|
||||
else:
|
||||
raise Exception('not support limit type, image ')
|
||||
resize_h = int(h * ratio)
|
||||
resize_w = int(w * ratio)
|
||||
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
from paddle.io import Dataset
|
||||
import json
|
||||
|
||||
from .imaug import transform, create_operators
|
||||
|
||||
|
||||
class PubTabDataSet(Dataset):
|
||||
def __init__(self, config, mode, logger, seed=None):
|
||||
super(PubTabDataSet, self).__init__()
|
||||
self.logger = logger
|
||||
|
||||
global_config = config['Global']
|
||||
dataset_config = config[mode]['dataset']
|
||||
loader_config = config[mode]['loader']
|
||||
|
||||
label_file_path = dataset_config.pop('label_file_path')
|
||||
|
||||
self.data_dir = dataset_config['data_dir']
|
||||
self.do_shuffle = loader_config['shuffle']
|
||||
self.do_hard_select = False
|
||||
if 'hard_select' in loader_config:
|
||||
self.do_hard_select = loader_config['hard_select']
|
||||
self.hard_prob = loader_config['hard_prob']
|
||||
if self.do_hard_select:
|
||||
self.img_select_prob = self.load_hard_select_prob()
|
||||
self.table_select_type = None
|
||||
if 'table_select_type' in loader_config:
|
||||
self.table_select_type = loader_config['table_select_type']
|
||||
self.table_select_prob = loader_config['table_select_prob']
|
||||
|
||||
self.seed = seed
|
||||
logger.info("Initialize indexs of datasets:%s" % label_file_path)
|
||||
with open(label_file_path, "rb") as f:
|
||||
self.data_lines = f.readlines()
|
||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||
if mode.lower() == "train":
|
||||
self.shuffle_data_random()
|
||||
self.ops = create_operators(dataset_config['transforms'], global_config)
|
||||
|
||||
def shuffle_data_random(self):
|
||||
if self.do_shuffle:
|
||||
random.seed(self.seed)
|
||||
random.shuffle(self.data_lines)
|
||||
return
|
||||
|
||||
def __getitem__(self, idx):
|
||||
try:
|
||||
data_line = self.data_lines[idx]
|
||||
data_line = data_line.decode('utf-8').strip("\n")
|
||||
info = json.loads(data_line)
|
||||
file_name = info['filename']
|
||||
select_flag = True
|
||||
if self.do_hard_select:
|
||||
prob = self.img_select_prob[file_name]
|
||||
if prob < random.uniform(0, 1):
|
||||
select_flag = False
|
||||
|
||||
if self.table_select_type:
|
||||
structure = info['html']['structure']['tokens'].copy()
|
||||
structure_str = ''.join(structure)
|
||||
table_type = "simple"
|
||||
if 'colspan' in structure_str or 'rowspan' in structure_str:
|
||||
table_type = "complex"
|
||||
if table_type == "complex":
|
||||
if self.table_select_prob < random.uniform(0, 1):
|
||||
select_flag = False
|
||||
|
||||
if select_flag:
|
||||
cells = info['html']['cells'].copy()
|
||||
structure = info['html']['structure'].copy()
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
data = {'img_path': img_path, 'cells': cells, 'structure':structure}
|
||||
if not os.path.exists(img_path):
|
||||
raise Exception("{} does not exist!".format(img_path))
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
outs = transform(data, self.ops)
|
||||
else:
|
||||
outs = None
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"When parsing line {}, error happened with msg: {}".format(
|
||||
data_line, e))
|
||||
outs = None
|
||||
if outs is None:
|
||||
return self.__getitem__(np.random.randint(self.__len__()))
|
||||
return outs
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_idx_order_list)
|
|
@ -69,12 +69,42 @@ class SimpleDataSet(Dataset):
|
|||
random.shuffle(self.data_lines)
|
||||
return
|
||||
|
||||
def get_ext_data(self):
|
||||
ext_data_num = 0
|
||||
for op in self.ops:
|
||||
if hasattr(op, 'ext_data_num'):
|
||||
ext_data_num = getattr(op, 'ext_data_num')
|
||||
break
|
||||
load_data_ops = self.ops[:2]
|
||||
ext_data = []
|
||||
|
||||
while len(ext_data) < ext_data_num:
|
||||
file_idx = self.data_idx_order_list[np.random.randint(self.__len__(
|
||||
))]
|
||||
data_line = self.data_lines[file_idx]
|
||||
data_line = data_line.decode('utf-8')
|
||||
substr = data_line.strip("\n").split(self.delimiter)
|
||||
file_name = substr[0]
|
||||
label = substr[1]
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
data = {'img_path': img_path, 'label': label}
|
||||
if not os.path.exists(img_path):
|
||||
continue
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
data = transform(data, load_data_ops)
|
||||
if data is None:
|
||||
continue
|
||||
ext_data.append(data)
|
||||
return ext_data
|
||||
|
||||
def __getitem__(self, idx):
|
||||
file_idx = self.data_idx_order_list[idx]
|
||||
data_line = self.data_lines[file_idx]
|
||||
try:
|
||||
data_line = data_line.decode('utf-8')
|
||||
substr = data_line.strip("\n").strip("\r").split(self.delimiter)
|
||||
substr = data_line.strip("\n").split(self.delimiter)
|
||||
file_name = substr[0]
|
||||
label = substr[1]
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
|
@ -84,6 +114,7 @@ class SimpleDataSet(Dataset):
|
|||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
data['ext_data'] = self.get_ext_data()
|
||||
outs = transform(data, self.ops)
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
|
|
|
@ -13,28 +13,39 @@
|
|||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
# det loss
|
||||
from .det_db_loss import DBLoss
|
||||
from .det_east_loss import EASTLoss
|
||||
from .det_sast_loss import SASTLoss
|
||||
|
||||
# rec loss
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
from .rec_att_loss import AttentionLoss
|
||||
from .rec_srn_loss import SRNLoss
|
||||
|
||||
# cls loss
|
||||
from .cls_loss import ClsLoss
|
||||
|
||||
# e2e loss
|
||||
from .e2e_pg_loss import PGLoss
|
||||
|
||||
# basic loss function
|
||||
from .basic_loss import DistanceLoss
|
||||
|
||||
# combined loss function
|
||||
from .combined_loss import CombinedLoss
|
||||
|
||||
# table loss
|
||||
from .table_att_loss import TableAttentionLoss
|
||||
|
||||
def build_loss(config):
|
||||
# det loss
|
||||
from .det_db_loss import DBLoss
|
||||
from .det_east_loss import EASTLoss
|
||||
from .det_sast_loss import SASTLoss
|
||||
|
||||
# rec loss
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
from .rec_att_loss import AttentionLoss
|
||||
from .rec_srn_loss import SRNLoss
|
||||
|
||||
# cls loss
|
||||
from .cls_loss import ClsLoss
|
||||
|
||||
# e2e loss
|
||||
from .e2e_pg_loss import PGLoss
|
||||
support_dict = [
|
||||
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||
'SRNLoss', 'PGLoss']
|
||||
|
||||
'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
|
||||
]
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('loss only support {}'.format(
|
||||
|
|
|
@ -0,0 +1,128 @@
|
|||
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
from paddle.nn import L1Loss
|
||||
from paddle.nn import MSELoss as L2Loss
|
||||
from paddle.nn import SmoothL1Loss
|
||||
|
||||
|
||||
class CELoss(nn.Layer):
|
||||
def __init__(self, epsilon=None):
|
||||
super().__init__()
|
||||
if epsilon is not None and (epsilon <= 0 or epsilon >= 1):
|
||||
epsilon = None
|
||||
self.epsilon = epsilon
|
||||
|
||||
def _labelsmoothing(self, target, class_num):
|
||||
if target.shape[-1] != class_num:
|
||||
one_hot_target = F.one_hot(target, class_num)
|
||||
else:
|
||||
one_hot_target = target
|
||||
soft_target = F.label_smooth(one_hot_target, epsilon=self.epsilon)
|
||||
soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
|
||||
return soft_target
|
||||
|
||||
def forward(self, x, label):
|
||||
loss_dict = {}
|
||||
if self.epsilon is not None:
|
||||
class_num = x.shape[-1]
|
||||
label = self._labelsmoothing(label, class_num)
|
||||
x = -F.log_softmax(x, axis=-1)
|
||||
loss = paddle.sum(x * label, axis=-1)
|
||||
else:
|
||||
if label.shape[-1] == x.shape[-1]:
|
||||
label = F.softmax(label, axis=-1)
|
||||
soft_label = True
|
||||
else:
|
||||
soft_label = False
|
||||
loss = F.cross_entropy(x, label=label, soft_label=soft_label)
|
||||
return loss
|
||||
|
||||
|
||||
class KLJSLoss(object):
|
||||
def __init__(self, mode='kl'):
|
||||
assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
|
||||
self.mode = mode
|
||||
|
||||
def __call__(self, p1, p2, reduction="mean"):
|
||||
|
||||
loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5))
|
||||
|
||||
if self.mode.lower() == "js":
|
||||
loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5))
|
||||
loss *= 0.5
|
||||
if reduction == "mean":
|
||||
loss = paddle.mean(loss, axis=[1,2])
|
||||
elif reduction=="none" or reduction is None:
|
||||
return loss
|
||||
else:
|
||||
loss = paddle.sum(loss, axis=[1,2])
|
||||
|
||||
return loss
|
||||
|
||||
class DMLLoss(nn.Layer):
|
||||
"""
|
||||
DMLLoss
|
||||
"""
|
||||
|
||||
def __init__(self, act=None):
|
||||
super().__init__()
|
||||
if act is not None:
|
||||
assert act in ["softmax", "sigmoid"]
|
||||
if act == "softmax":
|
||||
self.act = nn.Softmax(axis=-1)
|
||||
elif act == "sigmoid":
|
||||
self.act = nn.Sigmoid()
|
||||
else:
|
||||
self.act = None
|
||||
|
||||
self.jskl_loss = KLJSLoss(mode="js")
|
||||
|
||||
def forward(self, out1, out2):
|
||||
if self.act is not None:
|
||||
out1 = self.act(out1)
|
||||
out2 = self.act(out2)
|
||||
if len(out1.shape) < 2:
|
||||
log_out1 = paddle.log(out1)
|
||||
log_out2 = paddle.log(out2)
|
||||
loss = (F.kl_div(
|
||||
log_out1, out2, reduction='batchmean') + F.kl_div(
|
||||
log_out2, out1, reduction='batchmean')) / 2.0
|
||||
else:
|
||||
loss = self.jskl_loss(out1, out2)
|
||||
return loss
|
||||
|
||||
|
||||
class DistanceLoss(nn.Layer):
|
||||
"""
|
||||
DistanceLoss:
|
||||
mode: loss mode
|
||||
"""
|
||||
|
||||
def __init__(self, mode="l2", **kargs):
|
||||
super().__init__()
|
||||
assert mode in ["l1", "l2", "smooth_l1"]
|
||||
if mode == "l1":
|
||||
self.loss_func = nn.L1Loss(**kargs)
|
||||
elif mode == "l2":
|
||||
self.loss_func = nn.MSELoss(**kargs)
|
||||
elif mode == "smooth_l1":
|
||||
self.loss_func = nn.SmoothL1Loss(**kargs)
|
||||
|
||||
def forward(self, x, y):
|
||||
return self.loss_func(x, y)
|
|
@ -24,7 +24,7 @@ class ClsLoss(nn.Layer):
|
|||
super(ClsLoss, self).__init__()
|
||||
self.loss_func = nn.CrossEntropyLoss(reduction='mean')
|
||||
|
||||
def __call__(self, predicts, batch):
|
||||
def forward(self, predicts, batch):
|
||||
label = batch[1]
|
||||
loss = self.loss_func(input=predicts, label=label)
|
||||
return {'loss': loss}
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
from .distillation_loss import DistillationCTCLoss
|
||||
from .distillation_loss import DistillationDMLLoss
|
||||
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
|
||||
|
||||
|
||||
class CombinedLoss(nn.Layer):
|
||||
"""
|
||||
CombinedLoss:
|
||||
a combionation of loss function
|
||||
"""
|
||||
|
||||
def __init__(self, loss_config_list=None):
|
||||
super().__init__()
|
||||
self.loss_func = []
|
||||
self.loss_weight = []
|
||||
assert isinstance(loss_config_list, list), (
|
||||
'operator config should be a list')
|
||||
for config in loss_config_list:
|
||||
assert isinstance(config,
|
||||
dict) and len(config) == 1, "yaml format error"
|
||||
name = list(config)[0]
|
||||
param = config[name]
|
||||
assert "weight" in param, "weight must be in param, but param just contains {}".format(
|
||||
param.keys())
|
||||
self.loss_weight.append(param.pop("weight"))
|
||||
self.loss_func.append(eval(name)(**param))
|
||||
|
||||
def forward(self, input, batch, **kargs):
|
||||
loss_dict = {}
|
||||
loss_all = 0.
|
||||
for idx, loss_func in enumerate(self.loss_func):
|
||||
loss = loss_func(input, batch, **kargs)
|
||||
if isinstance(loss, paddle.Tensor):
|
||||
loss = {"loss_{}_{}".format(str(loss), idx): loss}
|
||||
weight = self.loss_weight[idx]
|
||||
for key in loss.keys():
|
||||
if key == "loss":
|
||||
loss_all += loss[key] * weight
|
||||
else:
|
||||
loss_dict["{}_{}".format(key, idx)] = loss[key]
|
||||
loss_dict["loss"] = loss_all
|
||||
return loss_dict
|
|
@ -0,0 +1,270 @@
|
|||
#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
from .basic_loss import DMLLoss
|
||||
from .basic_loss import DistanceLoss
|
||||
from .det_db_loss import DBLoss
|
||||
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
|
||||
|
||||
|
||||
def _sum_loss(loss_dict):
|
||||
if "loss" in loss_dict.keys():
|
||||
return loss_dict
|
||||
else:
|
||||
loss_dict["loss"] = 0.
|
||||
for k, value in loss_dict.items():
|
||||
if k == "loss":
|
||||
continue
|
||||
else:
|
||||
loss_dict["loss"] += value
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDMLLoss(DMLLoss):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_name_pairs=[],
|
||||
act=None,
|
||||
key=None,
|
||||
maps_name=None,
|
||||
name="dml"):
|
||||
super().__init__(act=act)
|
||||
assert isinstance(model_name_pairs, list)
|
||||
self.key = key
|
||||
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
|
||||
self.name = name
|
||||
self.maps_name = self._check_maps_name(maps_name)
|
||||
|
||||
def _check_model_name_pairs(self, model_name_pairs):
|
||||
if not isinstance(model_name_pairs, list):
|
||||
return []
|
||||
elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str):
|
||||
return model_name_pairs
|
||||
else:
|
||||
return [model_name_pairs]
|
||||
|
||||
def _check_maps_name(self, maps_name):
|
||||
if maps_name is None:
|
||||
return None
|
||||
elif type(maps_name) == str:
|
||||
return [maps_name]
|
||||
elif type(maps_name) == list:
|
||||
return [maps_name]
|
||||
else:
|
||||
return None
|
||||
|
||||
def _slice_out(self, outs):
|
||||
new_outs = {}
|
||||
for k in self.maps_name:
|
||||
if k == "thrink_maps":
|
||||
new_outs[k] = outs[:, 0, :, :]
|
||||
elif k == "threshold_maps":
|
||||
new_outs[k] = outs[:, 1, :, :]
|
||||
elif k == "binary_maps":
|
||||
new_outs[k] = outs[:, 2, :, :]
|
||||
else:
|
||||
continue
|
||||
return new_outs
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
for idx, pair in enumerate(self.model_name_pairs):
|
||||
out1 = predicts[pair[0]]
|
||||
out2 = predicts[pair[1]]
|
||||
if self.key is not None:
|
||||
out1 = out1[self.key]
|
||||
out2 = out2[self.key]
|
||||
|
||||
if self.maps_name is None:
|
||||
loss = super().forward(out1, out2)
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
|
||||
idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, idx)] = loss
|
||||
else:
|
||||
outs1 = self._slice_out(out1)
|
||||
outs2 = self._slice_out(out2)
|
||||
for _c, k in enumerate(outs1.keys()):
|
||||
loss = super().forward(outs1[k], outs2[k])
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
|
||||
0], pair[1], map_name, idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c],
|
||||
idx)] = loss
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationCTCLoss(CTCLoss):
|
||||
def __init__(self, model_name_list=[], key=None, name="loss_ctc"):
|
||||
super().__init__()
|
||||
self.model_name_list = model_name_list
|
||||
self.key = key
|
||||
self.name = name
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
for idx, model_name in enumerate(self.model_name_list):
|
||||
out = predicts[model_name]
|
||||
if self.key is not None:
|
||||
out = out[self.key]
|
||||
loss = super().forward(out, batch)
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}".format(self.name, model_name,
|
||||
idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, model_name)] = loss
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDBLoss(DBLoss):
|
||||
def __init__(self,
|
||||
model_name_list=[],
|
||||
balance_loss=True,
|
||||
main_loss_type='DiceLoss',
|
||||
alpha=5,
|
||||
beta=10,
|
||||
ohem_ratio=3,
|
||||
eps=1e-6,
|
||||
name="db",
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.model_name_list = model_name_list
|
||||
self.name = name
|
||||
self.key = None
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = {}
|
||||
for idx, model_name in enumerate(self.model_name_list):
|
||||
out = predicts[model_name]
|
||||
if self.key is not None:
|
||||
out = out[self.key]
|
||||
loss = super().forward(out, batch)
|
||||
|
||||
if isinstance(loss, dict):
|
||||
for key in loss.keys():
|
||||
if key == "loss":
|
||||
continue
|
||||
name = "{}_{}_{}".format(self.name, model_name, key)
|
||||
loss_dict[name] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, model_name)] = loss
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDilaDBLoss(DBLoss):
|
||||
def __init__(self,
|
||||
model_name_pairs=[],
|
||||
key=None,
|
||||
balance_loss=True,
|
||||
main_loss_type='DiceLoss',
|
||||
alpha=5,
|
||||
beta=10,
|
||||
ohem_ratio=3,
|
||||
eps=1e-6,
|
||||
name="dila_dbloss"):
|
||||
super().__init__()
|
||||
self.model_name_pairs = model_name_pairs
|
||||
self.name = name
|
||||
self.key = key
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
for idx, pair in enumerate(self.model_name_pairs):
|
||||
stu_outs = predicts[pair[0]]
|
||||
tch_outs = predicts[pair[1]]
|
||||
if self.key is not None:
|
||||
stu_preds = stu_outs[self.key]
|
||||
tch_preds = tch_outs[self.key]
|
||||
|
||||
stu_shrink_maps = stu_preds[:, 0, :, :]
|
||||
stu_binary_maps = stu_preds[:, 2, :, :]
|
||||
|
||||
# dilation to teacher prediction
|
||||
dilation_w = np.array([[1, 1], [1, 1]])
|
||||
th_shrink_maps = tch_preds[:, 0, :, :]
|
||||
th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3
|
||||
dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
|
||||
for i in range(th_shrink_maps.shape[0]):
|
||||
dilate_maps[i] = cv2.dilate(
|
||||
th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
|
||||
th_shrink_maps = paddle.to_tensor(dilate_maps)
|
||||
|
||||
label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[
|
||||
1:]
|
||||
|
||||
# calculate the shrink map loss
|
||||
bce_loss = self.alpha * self.bce_loss(
|
||||
stu_shrink_maps, th_shrink_maps, label_shrink_mask)
|
||||
loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
|
||||
label_shrink_mask)
|
||||
|
||||
# k = f"{self.name}_{pair[0]}_{pair[1]}"
|
||||
k = "{}_{}_{}".format(self.name, pair[0], pair[1])
|
||||
loss_dict[k] = bce_loss + loss_binary_maps
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDistanceLoss(DistanceLoss):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
mode="l2",
|
||||
model_name_pairs=[],
|
||||
key=None,
|
||||
name="loss_distance",
|
||||
**kargs):
|
||||
super().__init__(mode=mode, **kargs)
|
||||
assert isinstance(model_name_pairs, list)
|
||||
self.key = key
|
||||
self.model_name_pairs = model_name_pairs
|
||||
self.name = name + "_l2"
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
for idx, pair in enumerate(self.model_name_pairs):
|
||||
out1 = predicts[pair[0]]
|
||||
out2 = predicts[pair[1]]
|
||||
if self.key is not None:
|
||||
out1 = out1[self.key]
|
||||
out2 = out2[self.key]
|
||||
loss = super().forward(out1, out2)
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}".format(self.name, key, idx)] = loss[
|
||||
key]
|
||||
else:
|
||||
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
|
||||
idx)] = loss
|
||||
return loss_dict
|
|
@ -25,7 +25,7 @@ class CTCLoss(nn.Layer):
|
|||
super(CTCLoss, self).__init__()
|
||||
self.loss_func = nn.CTCLoss(blank=0, reduction='none')
|
||||
|
||||
def __call__(self, predicts, batch):
|
||||
def forward(self, predicts, batch):
|
||||
predicts = predicts.transpose((1, 0, 2))
|
||||
N, B, _ = predicts.shape
|
||||
preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
from paddle import fluid
|
||||
|
||||
class TableAttentionLoss(nn.Layer):
|
||||
def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
|
||||
super(TableAttentionLoss, self).__init__()
|
||||
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
|
||||
self.structure_weight = structure_weight
|
||||
self.loc_weight = loc_weight
|
||||
self.use_giou = use_giou
|
||||
self.giou_weight = giou_weight
|
||||
|
||||
def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
|
||||
'''
|
||||
:param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
|
||||
:param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
|
||||
:return: loss
|
||||
'''
|
||||
ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0])
|
||||
iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1])
|
||||
ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2])
|
||||
iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3])
|
||||
|
||||
iw = fluid.layers.clip(ix2 - ix1 + 1e-3, 0., 1e10)
|
||||
ih = fluid.layers.clip(iy2 - iy1 + 1e-3, 0., 1e10)
|
||||
|
||||
# overlap
|
||||
inters = iw * ih
|
||||
|
||||
# union
|
||||
uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3
|
||||
) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * (
|
||||
bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps
|
||||
|
||||
# ious
|
||||
ious = inters / uni
|
||||
|
||||
ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0])
|
||||
ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1])
|
||||
ex2 = fluid.layers.elementwise_max(preds[:, 2], bbox[:, 2])
|
||||
ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3])
|
||||
ew = fluid.layers.clip(ex2 - ex1 + 1e-3, 0., 1e10)
|
||||
eh = fluid.layers.clip(ey2 - ey1 + 1e-3, 0., 1e10)
|
||||
|
||||
# enclose erea
|
||||
enclose = ew * eh + eps
|
||||
giou = ious - (enclose - uni) / enclose
|
||||
|
||||
loss = 1 - giou
|
||||
|
||||
if reduction == 'mean':
|
||||
loss = paddle.mean(loss)
|
||||
elif reduction == 'sum':
|
||||
loss = paddle.sum(loss)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return loss
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
structure_probs = predicts['structure_probs']
|
||||
structure_targets = batch[1].astype("int64")
|
||||
structure_targets = structure_targets[:, 1:]
|
||||
if len(batch) == 6:
|
||||
structure_mask = batch[5].astype("int64")
|
||||
structure_mask = structure_mask[:, 1:]
|
||||
structure_mask = paddle.reshape(structure_mask, [-1])
|
||||
structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
|
||||
structure_targets = paddle.reshape(structure_targets, [-1])
|
||||
structure_loss = self.loss_func(structure_probs, structure_targets)
|
||||
|
||||
if len(batch) == 6:
|
||||
structure_loss = structure_loss * structure_mask
|
||||
|
||||
# structure_loss = paddle.sum(structure_loss) * self.structure_weight
|
||||
structure_loss = paddle.mean(structure_loss) * self.structure_weight
|
||||
|
||||
loc_preds = predicts['loc_preds']
|
||||
loc_targets = batch[2].astype("float32")
|
||||
loc_targets_mask = batch[4].astype("float32")
|
||||
loc_targets = loc_targets[:, 1:, :]
|
||||
loc_targets_mask = loc_targets_mask[:, 1:, :]
|
||||
loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight
|
||||
if self.use_giou:
|
||||
loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight
|
||||
total_loss = structure_loss + loc_loss + loc_loss_giou
|
||||
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou}
|
||||
else:
|
||||
total_loss = structure_loss + loc_loss
|
||||
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss}
|
|
@ -19,20 +19,23 @@ from __future__ import unicode_literals
|
|||
|
||||
import copy
|
||||
|
||||
__all__ = ['build_metric']
|
||||
__all__ = ["build_metric"]
|
||||
|
||||
from .det_metric import DetMetric
|
||||
from .rec_metric import RecMetric
|
||||
from .cls_metric import ClsMetric
|
||||
from .e2e_metric import E2EMetric
|
||||
from .distillation_metric import DistillationMetric
|
||||
from .table_metric import TableMetric
|
||||
|
||||
def build_metric(config):
|
||||
from .det_metric import DetMetric
|
||||
from .rec_metric import RecMetric
|
||||
from .cls_metric import ClsMetric
|
||||
from .e2e_metric import E2EMetric
|
||||
|
||||
support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric']
|
||||
support_dict = [
|
||||
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric"
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
module_name = config.pop("name")
|
||||
assert module_name in support_dict, Exception(
|
||||
'metric only support {}'.format(support_dict))
|
||||
"metric only support {}".format(support_dict))
|
||||
module_class = eval(module_name)(**config)
|
||||
return module_class
|
||||
|
|
|
@ -55,6 +55,7 @@ class DetMetric(object):
|
|||
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
|
||||
self.results.append(result)
|
||||
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
return metrics {
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import copy
|
||||
|
||||
from .rec_metric import RecMetric
|
||||
from .det_metric import DetMetric
|
||||
from .e2e_metric import E2EMetric
|
||||
from .cls_metric import ClsMetric
|
||||
|
||||
|
||||
class DistillationMetric(object):
|
||||
def __init__(self,
|
||||
key=None,
|
||||
base_metric_name=None,
|
||||
main_indicator=None,
|
||||
**kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.key = key
|
||||
self.main_indicator = main_indicator
|
||||
self.base_metric_name = base_metric_name
|
||||
self.kwargs = kwargs
|
||||
self.metrics = None
|
||||
|
||||
def _init_metrcis(self, preds):
|
||||
self.metrics = dict()
|
||||
mod = importlib.import_module(__name__)
|
||||
for key in preds:
|
||||
self.metrics[key] = getattr(mod, self.base_metric_name)(
|
||||
main_indicator=self.main_indicator, **self.kwargs)
|
||||
self.metrics[key].reset()
|
||||
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
assert isinstance(preds, dict)
|
||||
if self.metrics is None:
|
||||
self._init_metrcis(preds)
|
||||
output = dict()
|
||||
for key in preds:
|
||||
self.metrics[key].__call__(preds[key], batch, **kwargs)
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
return metrics {
|
||||
'acc': 0,
|
||||
'norm_edit_dis': 0,
|
||||
}
|
||||
"""
|
||||
output = dict()
|
||||
for key in self.metrics:
|
||||
metric = self.metrics[key].get_metric()
|
||||
# main indicator
|
||||
if key == self.key:
|
||||
output.update(metric)
|
||||
else:
|
||||
for sub_key in metric:
|
||||
output["{}_{}".format(key, sub_key)] = metric[sub_key]
|
||||
return output
|
||||
|
||||
def reset(self):
|
||||
for key in self.metrics:
|
||||
self.metrics[key].reset()
|
|
@ -0,0 +1,50 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
class TableMetric(object):
|
||||
def __init__(self, main_indicator='acc', **kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.reset()
|
||||
|
||||
def __call__(self, pred, batch, *args, **kwargs):
|
||||
structure_probs = pred['structure_probs'].numpy()
|
||||
structure_labels = batch[1]
|
||||
correct_num = 0
|
||||
all_num = 0
|
||||
structure_probs = np.argmax(structure_probs, axis=2)
|
||||
structure_labels = structure_labels[:, 1:]
|
||||
batch_size = structure_probs.shape[0]
|
||||
for bno in range(batch_size):
|
||||
all_num += 1
|
||||
if (structure_probs[bno] == structure_labels[bno]).all():
|
||||
correct_num += 1
|
||||
self.correct_num += correct_num
|
||||
self.all_num += all_num
|
||||
return {
|
||||
'acc': correct_num * 1.0 / all_num,
|
||||
}
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
return metrics {
|
||||
'acc': 0,
|
||||
}
|
||||
"""
|
||||
acc = 1.0 * self.correct_num / self.all_num
|
||||
self.reset()
|
||||
return {'acc': acc}
|
||||
|
||||
def reset(self):
|
||||
self.correct_num = 0
|
||||
self.all_num = 0
|
|
@ -13,12 +13,20 @@
|
|||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import importlib
|
||||
|
||||
from .base_model import BaseModel
|
||||
from .distillation_model import DistillationModel
|
||||
|
||||
__all__ = ['build_model']
|
||||
|
||||
|
||||
def build_model(config):
|
||||
from .base_model import BaseModel
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
module_class = BaseModel(config)
|
||||
return module_class
|
||||
if not "name" in config:
|
||||
arch = BaseModel(config)
|
||||
else:
|
||||
name = config.pop("name")
|
||||
mod = importlib.import_module(__name__)
|
||||
arch = getattr(mod, name)(config)
|
||||
return arch
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -32,7 +32,6 @@ class BaseModel(nn.Layer):
|
|||
config (dict): the super parameters for module.
|
||||
"""
|
||||
super(BaseModel, self).__init__()
|
||||
|
||||
in_channels = config.get('in_channels', 3)
|
||||
model_type = config['model_type']
|
||||
# build transfrom,
|
||||
|
@ -68,14 +67,23 @@ class BaseModel(nn.Layer):
|
|||
config["Head"]['in_channels'] = in_channels
|
||||
self.head = build_head(config["Head"])
|
||||
|
||||
self.return_all_feats = config.get("return_all_feats", False)
|
||||
|
||||
def forward(self, x, data=None):
|
||||
y = dict()
|
||||
if self.use_transform:
|
||||
x = self.transform(x)
|
||||
x = self.backbone(x)
|
||||
y["backbone_out"] = x
|
||||
if self.use_neck:
|
||||
x = self.neck(x)
|
||||
if data is None:
|
||||
x = self.head(x)
|
||||
y["neck_out"] = x
|
||||
x = self.head(x, targets=data)
|
||||
if isinstance(x, dict):
|
||||
y.update(x)
|
||||
else:
|
||||
x = self.head(x, data)
|
||||
return x
|
||||
y["head_out"] = x
|
||||
if self.return_all_feats:
|
||||
return y
|
||||
else:
|
||||
return x
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from paddle import nn
|
||||
from ppocr.modeling.transforms import build_transform
|
||||
from ppocr.modeling.backbones import build_backbone
|
||||
from ppocr.modeling.necks import build_neck
|
||||
from ppocr.modeling.heads import build_head
|
||||
from .base_model import BaseModel
|
||||
from ppocr.utils.save_load import init_model, load_pretrained_params
|
||||
|
||||
__all__ = ['DistillationModel']
|
||||
|
||||
|
||||
class DistillationModel(nn.Layer):
|
||||
def __init__(self, config):
|
||||
"""
|
||||
the module for OCR distillation.
|
||||
args:
|
||||
config (dict): the super parameters for module.
|
||||
"""
|
||||
super().__init__()
|
||||
self.model_list = []
|
||||
self.model_name_list = []
|
||||
for key in config["Models"]:
|
||||
model_config = config["Models"][key]
|
||||
freeze_params = False
|
||||
pretrained = None
|
||||
if "freeze_params" in model_config:
|
||||
freeze_params = model_config.pop("freeze_params")
|
||||
if "pretrained" in model_config:
|
||||
pretrained = model_config.pop("pretrained")
|
||||
model = BaseModel(model_config)
|
||||
if pretrained is not None:
|
||||
load_pretrained_params(model, pretrained)
|
||||
if freeze_params:
|
||||
for param in model.parameters():
|
||||
param.trainable = False
|
||||
self.model_list.append(self.add_sublayer(key, model))
|
||||
self.model_name_list.append(key)
|
||||
|
||||
def forward(self, x):
|
||||
result_dict = dict()
|
||||
for idx, model_name in enumerate(self.model_name_list):
|
||||
result_dict[model_name] = self.model_list[idx](x)
|
||||
return result_dict
|
|
@ -12,29 +12,36 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = ['build_backbone']
|
||||
__all__ = ["build_backbone"]
|
||||
|
||||
|
||||
def build_backbone(config, model_type):
|
||||
if model_type == 'det':
|
||||
if model_type == "det":
|
||||
from .det_mobilenet_v3 import MobileNetV3
|
||||
from .det_resnet_vd import ResNet
|
||||
from .det_resnet_vd_sast import ResNet_SAST
|
||||
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
|
||||
elif model_type == 'rec' or model_type == 'cls':
|
||||
support_dict = ["MobileNetV3", "ResNet", "ResNet_SAST"]
|
||||
elif model_type == "rec" or model_type == "cls":
|
||||
from .rec_mobilenet_v3 import MobileNetV3
|
||||
from .rec_resnet_vd import ResNet
|
||||
from .rec_resnet_fpn import ResNetFPN
|
||||
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
|
||||
elif model_type == 'e2e':
|
||||
from .rec_mv1_enhance import MobileNetV1Enhance
|
||||
support_dict = [
|
||||
"MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN"
|
||||
]
|
||||
elif model_type == "e2e":
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
support_dict = ['ResNet']
|
||||
support_dict = ["ResNet"]
|
||||
elif model_type == "table":
|
||||
from .table_resnet_vd import ResNet
|
||||
from .table_mobilenet_v3 import MobileNetV3
|
||||
support_dict = ["ResNet", "MobileNetV3"]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
module_name = config.pop('name')
|
||||
module_name = config.pop("name")
|
||||
assert module_name in support_dict, Exception(
|
||||
'when model typs is {}, backbone only support {}'.format(model_type,
|
||||
"when model typs is {}, backbone only support {}".format(model_type,
|
||||
support_dict))
|
||||
module_class = eval(module_name)(**config)
|
||||
return module_class
|
||||
|
|
|
@ -102,8 +102,7 @@ class MobileNetV3(nn.Layer):
|
|||
padding=1,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act='hardswish',
|
||||
name='conv1')
|
||||
act='hardswish')
|
||||
|
||||
self.stages = []
|
||||
self.out_channels = []
|
||||
|
@ -125,8 +124,7 @@ class MobileNetV3(nn.Layer):
|
|||
kernel_size=k,
|
||||
stride=s,
|
||||
use_se=se,
|
||||
act=nl,
|
||||
name="conv" + str(i + 2)))
|
||||
act=nl))
|
||||
inplanes = make_divisible(scale * c)
|
||||
i += 1
|
||||
block_list.append(
|
||||
|
@ -138,8 +136,7 @@ class MobileNetV3(nn.Layer):
|
|||
padding=0,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act='hardswish',
|
||||
name='conv_last'))
|
||||
act='hardswish'))
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
|
||||
for i, stage in enumerate(self.stages):
|
||||
|
@ -163,8 +160,7 @@ class ConvBNLayer(nn.Layer):
|
|||
padding,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act=None,
|
||||
name=None):
|
||||
act=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
self.if_act = if_act
|
||||
self.act = act
|
||||
|
@ -175,16 +171,9 @@ class ConvBNLayer(nn.Layer):
|
|||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + '_weights'),
|
||||
bias_attr=False)
|
||||
|
||||
self.bn = nn.BatchNorm(
|
||||
num_channels=out_channels,
|
||||
act=None,
|
||||
param_attr=ParamAttr(name=name + "_bn_scale"),
|
||||
bias_attr=ParamAttr(name=name + "_bn_offset"),
|
||||
moving_mean_name=name + "_bn_mean",
|
||||
moving_variance_name=name + "_bn_variance")
|
||||
self.bn = nn.BatchNorm(num_channels=out_channels, act=None)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
|
@ -209,8 +198,7 @@ class ResidualUnit(nn.Layer):
|
|||
kernel_size,
|
||||
stride,
|
||||
use_se,
|
||||
act=None,
|
||||
name=''):
|
||||
act=None):
|
||||
super(ResidualUnit, self).__init__()
|
||||
self.if_shortcut = stride == 1 and in_channels == out_channels
|
||||
self.if_se = use_se
|
||||
|
@ -222,8 +210,7 @@ class ResidualUnit(nn.Layer):
|
|||
stride=1,
|
||||
padding=0,
|
||||
if_act=True,
|
||||
act=act,
|
||||
name=name + "_expand")
|
||||
act=act)
|
||||
self.bottleneck_conv = ConvBNLayer(
|
||||
in_channels=mid_channels,
|
||||
out_channels=mid_channels,
|
||||
|
@ -232,10 +219,9 @@ class ResidualUnit(nn.Layer):
|
|||
padding=int((kernel_size - 1) // 2),
|
||||
groups=mid_channels,
|
||||
if_act=True,
|
||||
act=act,
|
||||
name=name + "_depthwise")
|
||||
act=act)
|
||||
if self.if_se:
|
||||
self.mid_se = SEModule(mid_channels, name=name + "_se")
|
||||
self.mid_se = SEModule(mid_channels)
|
||||
self.linear_conv = ConvBNLayer(
|
||||
in_channels=mid_channels,
|
||||
out_channels=out_channels,
|
||||
|
@ -243,8 +229,7 @@ class ResidualUnit(nn.Layer):
|
|||
stride=1,
|
||||
padding=0,
|
||||
if_act=False,
|
||||
act=None,
|
||||
name=name + "_linear")
|
||||
act=None)
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self.expand_conv(inputs)
|
||||
|
@ -258,7 +243,7 @@ class ResidualUnit(nn.Layer):
|
|||
|
||||
|
||||
class SEModule(nn.Layer):
|
||||
def __init__(self, in_channels, reduction=4, name=""):
|
||||
def __init__(self, in_channels, reduction=4):
|
||||
super(SEModule, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2D(1)
|
||||
self.conv1 = nn.Conv2D(
|
||||
|
@ -266,17 +251,13 @@ class SEModule(nn.Layer):
|
|||
out_channels=in_channels // reduction,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
weight_attr=ParamAttr(name=name + "_1_weights"),
|
||||
bias_attr=ParamAttr(name=name + "_1_offset"))
|
||||
padding=0)
|
||||
self.conv2 = nn.Conv2D(
|
||||
in_channels=in_channels // reduction,
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
weight_attr=ParamAttr(name + "_2_weights"),
|
||||
bias_attr=ParamAttr(name=name + "_2_offset"))
|
||||
padding=0)
|
||||
|
||||
def forward(self, inputs):
|
||||
outputs = self.avg_pool(inputs)
|
||||
|
|
|
@ -96,8 +96,7 @@ class MobileNetV3(nn.Layer):
|
|||
padding=1,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act='hardswish',
|
||||
name='conv1')
|
||||
act='hardswish')
|
||||
i = 0
|
||||
block_list = []
|
||||
inplanes = make_divisible(inplanes * scale)
|
||||
|
@ -110,8 +109,7 @@ class MobileNetV3(nn.Layer):
|
|||
kernel_size=k,
|
||||
stride=s,
|
||||
use_se=se,
|
||||
act=nl,
|
||||
name='conv' + str(i + 2)))
|
||||
act=nl))
|
||||
inplanes = make_divisible(scale * c)
|
||||
i += 1
|
||||
self.blocks = nn.Sequential(*block_list)
|
||||
|
@ -124,8 +122,7 @@ class MobileNetV3(nn.Layer):
|
|||
padding=0,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act='hardswish',
|
||||
name='conv_last')
|
||||
act='hardswish')
|
||||
|
||||
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
|
||||
self.out_channels = make_divisible(scale * cls_ch_squeeze)
|
||||
|
|
|
@ -0,0 +1,256 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
|
||||
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
|
||||
from paddle.nn.initializer import KaimingNormal
|
||||
import math
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import ParamAttr, reshape, transpose, concat, split
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
|
||||
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
|
||||
from paddle.nn.initializer import KaimingNormal
|
||||
import math
|
||||
from paddle.nn.functional import hardswish, hardsigmoid
|
||||
from paddle.regularizer import L2Decay
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
filter_size,
|
||||
num_filters,
|
||||
stride,
|
||||
padding,
|
||||
channels=None,
|
||||
num_groups=1,
|
||||
act='hard_swish'):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
|
||||
self._conv = Conv2D(
|
||||
in_channels=num_channels,
|
||||
out_channels=num_filters,
|
||||
kernel_size=filter_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=num_groups,
|
||||
weight_attr=ParamAttr(initializer=KaimingNormal()),
|
||||
bias_attr=False)
|
||||
|
||||
self._batch_norm = BatchNorm(
|
||||
num_filters,
|
||||
act=act,
|
||||
param_attr=ParamAttr(regularizer=L2Decay(0.0)),
|
||||
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
return y
|
||||
|
||||
|
||||
class DepthwiseSeparable(nn.Layer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
num_filters1,
|
||||
num_filters2,
|
||||
num_groups,
|
||||
stride,
|
||||
scale,
|
||||
dw_size=3,
|
||||
padding=1,
|
||||
use_se=False):
|
||||
super(DepthwiseSeparable, self).__init__()
|
||||
self.use_se = use_se
|
||||
self._depthwise_conv = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=int(num_filters1 * scale),
|
||||
filter_size=dw_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
num_groups=int(num_groups * scale))
|
||||
if use_se:
|
||||
self._se = SEModule(int(num_filters1 * scale))
|
||||
self._pointwise_conv = ConvBNLayer(
|
||||
num_channels=int(num_filters1 * scale),
|
||||
filter_size=1,
|
||||
num_filters=int(num_filters2 * scale),
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._depthwise_conv(inputs)
|
||||
if self.use_se:
|
||||
y = self._se(y)
|
||||
y = self._pointwise_conv(y)
|
||||
return y
|
||||
|
||||
|
||||
class MobileNetV1Enhance(nn.Layer):
|
||||
def __init__(self, in_channels=3, scale=0.5, **kwargs):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
self.block_list = []
|
||||
|
||||
self.conv1 = ConvBNLayer(
|
||||
num_channels=3,
|
||||
filter_size=3,
|
||||
channels=3,
|
||||
num_filters=int(32 * scale),
|
||||
stride=2,
|
||||
padding=1)
|
||||
|
||||
conv2_1 = DepthwiseSeparable(
|
||||
num_channels=int(32 * scale),
|
||||
num_filters1=32,
|
||||
num_filters2=64,
|
||||
num_groups=32,
|
||||
stride=1,
|
||||
scale=scale)
|
||||
self.block_list.append(conv2_1)
|
||||
|
||||
conv2_2 = DepthwiseSeparable(
|
||||
num_channels=int(64 * scale),
|
||||
num_filters1=64,
|
||||
num_filters2=128,
|
||||
num_groups=64,
|
||||
stride=1,
|
||||
scale=scale)
|
||||
self.block_list.append(conv2_2)
|
||||
|
||||
conv3_1 = DepthwiseSeparable(
|
||||
num_channels=int(128 * scale),
|
||||
num_filters1=128,
|
||||
num_filters2=128,
|
||||
num_groups=128,
|
||||
stride=1,
|
||||
scale=scale)
|
||||
self.block_list.append(conv3_1)
|
||||
|
||||
conv3_2 = DepthwiseSeparable(
|
||||
num_channels=int(128 * scale),
|
||||
num_filters1=128,
|
||||
num_filters2=256,
|
||||
num_groups=128,
|
||||
stride=(2, 1),
|
||||
scale=scale)
|
||||
self.block_list.append(conv3_2)
|
||||
|
||||
conv4_1 = DepthwiseSeparable(
|
||||
num_channels=int(256 * scale),
|
||||
num_filters1=256,
|
||||
num_filters2=256,
|
||||
num_groups=256,
|
||||
stride=1,
|
||||
scale=scale)
|
||||
self.block_list.append(conv4_1)
|
||||
|
||||
conv4_2 = DepthwiseSeparable(
|
||||
num_channels=int(256 * scale),
|
||||
num_filters1=256,
|
||||
num_filters2=512,
|
||||
num_groups=256,
|
||||
stride=(2, 1),
|
||||
scale=scale)
|
||||
self.block_list.append(conv4_2)
|
||||
|
||||
for _ in range(5):
|
||||
conv5 = DepthwiseSeparable(
|
||||
num_channels=int(512 * scale),
|
||||
num_filters1=512,
|
||||
num_filters2=512,
|
||||
num_groups=512,
|
||||
stride=1,
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
scale=scale,
|
||||
use_se=False)
|
||||
self.block_list.append(conv5)
|
||||
|
||||
conv5_6 = DepthwiseSeparable(
|
||||
num_channels=int(512 * scale),
|
||||
num_filters1=512,
|
||||
num_filters2=1024,
|
||||
num_groups=512,
|
||||
stride=(2, 1),
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
scale=scale,
|
||||
use_se=True)
|
||||
self.block_list.append(conv5_6)
|
||||
|
||||
conv6 = DepthwiseSeparable(
|
||||
num_channels=int(1024 * scale),
|
||||
num_filters1=1024,
|
||||
num_filters2=1024,
|
||||
num_groups=1024,
|
||||
stride=1,
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
use_se=True,
|
||||
scale=scale)
|
||||
self.block_list.append(conv6)
|
||||
|
||||
self.block_list = nn.Sequential(*self.block_list)
|
||||
|
||||
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
|
||||
self.out_channels = int(1024 * scale)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv1(inputs)
|
||||
y = self.block_list(y)
|
||||
y = self.pool(y)
|
||||
return y
|
||||
|
||||
|
||||
class SEModule(nn.Layer):
|
||||
def __init__(self, channel, reduction=4):
|
||||
super(SEModule, self).__init__()
|
||||
self.avg_pool = AdaptiveAvgPool2D(1)
|
||||
self.conv1 = Conv2D(
|
||||
in_channels=channel,
|
||||
out_channels=channel // reduction,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
weight_attr=ParamAttr(),
|
||||
bias_attr=ParamAttr())
|
||||
self.conv2 = Conv2D(
|
||||
in_channels=channel // reduction,
|
||||
out_channels=channel,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
weight_attr=ParamAttr(),
|
||||
bias_attr=ParamAttr())
|
||||
|
||||
def forward(self, inputs):
|
||||
outputs = self.avg_pool(inputs)
|
||||
outputs = self.conv1(outputs)
|
||||
outputs = F.relu(outputs)
|
||||
outputs = self.conv2(outputs)
|
||||
outputs = hardsigmoid(outputs)
|
||||
return paddle.multiply(x=inputs, y=outputs)
|
|
@ -0,0 +1,287 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
|
||||
__all__ = ['MobileNetV3']
|
||||
|
||||
|
||||
def make_divisible(v, divisor=8, min_value=None):
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class MobileNetV3(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
model_name='large',
|
||||
scale=0.5,
|
||||
disable_se=False,
|
||||
**kwargs):
|
||||
"""
|
||||
the MobilenetV3 backbone network for detection module.
|
||||
Args:
|
||||
params(dict): the super parameters for build network
|
||||
"""
|
||||
super(MobileNetV3, self).__init__()
|
||||
|
||||
self.disable_se = disable_se
|
||||
|
||||
if model_name == "large":
|
||||
cfg = [
|
||||
# k, exp, c, se, nl, s,
|
||||
[3, 16, 16, False, 'relu', 1],
|
||||
[3, 64, 24, False, 'relu', 2],
|
||||
[3, 72, 24, False, 'relu', 1],
|
||||
[5, 72, 40, True, 'relu', 2],
|
||||
[5, 120, 40, True, 'relu', 1],
|
||||
[5, 120, 40, True, 'relu', 1],
|
||||
[3, 240, 80, False, 'hardswish', 2],
|
||||
[3, 200, 80, False, 'hardswish', 1],
|
||||
[3, 184, 80, False, 'hardswish', 1],
|
||||
[3, 184, 80, False, 'hardswish', 1],
|
||||
[3, 480, 112, True, 'hardswish', 1],
|
||||
[3, 672, 112, True, 'hardswish', 1],
|
||||
[5, 672, 160, True, 'hardswish', 2],
|
||||
[5, 960, 160, True, 'hardswish', 1],
|
||||
[5, 960, 160, True, 'hardswish', 1],
|
||||
]
|
||||
cls_ch_squeeze = 960
|
||||
elif model_name == "small":
|
||||
cfg = [
|
||||
# k, exp, c, se, nl, s,
|
||||
[3, 16, 16, True, 'relu', 2],
|
||||
[3, 72, 24, False, 'relu', 2],
|
||||
[3, 88, 24, False, 'relu', 1],
|
||||
[5, 96, 40, True, 'hardswish', 2],
|
||||
[5, 240, 40, True, 'hardswish', 1],
|
||||
[5, 240, 40, True, 'hardswish', 1],
|
||||
[5, 120, 48, True, 'hardswish', 1],
|
||||
[5, 144, 48, True, 'hardswish', 1],
|
||||
[5, 288, 96, True, 'hardswish', 2],
|
||||
[5, 576, 96, True, 'hardswish', 1],
|
||||
[5, 576, 96, True, 'hardswish', 1],
|
||||
]
|
||||
cls_ch_squeeze = 576
|
||||
else:
|
||||
raise NotImplementedError("mode[" + model_name +
|
||||
"_model] is not implemented!")
|
||||
|
||||
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
|
||||
assert scale in supported_scale, \
|
||||
"supported scale are {} but input scale is {}".format(supported_scale, scale)
|
||||
inplanes = 16
|
||||
# conv1
|
||||
self.conv = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=make_divisible(inplanes * scale),
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act='hardswish',
|
||||
name='conv1')
|
||||
|
||||
self.stages = []
|
||||
self.out_channels = []
|
||||
block_list = []
|
||||
i = 0
|
||||
inplanes = make_divisible(inplanes * scale)
|
||||
for (k, exp, c, se, nl, s) in cfg:
|
||||
se = se and not self.disable_se
|
||||
start_idx = 2 if model_name == 'large' else 0
|
||||
if s == 2 and i > start_idx:
|
||||
self.out_channels.append(inplanes)
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
block_list = []
|
||||
block_list.append(
|
||||
ResidualUnit(
|
||||
in_channels=inplanes,
|
||||
mid_channels=make_divisible(scale * exp),
|
||||
out_channels=make_divisible(scale * c),
|
||||
kernel_size=k,
|
||||
stride=s,
|
||||
use_se=se,
|
||||
act=nl,
|
||||
name="conv" + str(i + 2)))
|
||||
inplanes = make_divisible(scale * c)
|
||||
i += 1
|
||||
block_list.append(
|
||||
ConvBNLayer(
|
||||
in_channels=inplanes,
|
||||
out_channels=make_divisible(scale * cls_ch_squeeze),
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act='hardswish',
|
||||
name='conv_last'))
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
|
||||
for i, stage in enumerate(self.stages):
|
||||
self.add_sublayer(sublayer=stage, name="stage{}".format(i))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
out_list = []
|
||||
for stage in self.stages:
|
||||
x = stage(x)
|
||||
out_list.append(x)
|
||||
return out_list
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
self.if_act = if_act
|
||||
self.act = act
|
||||
self.conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + '_weights'),
|
||||
bias_attr=False)
|
||||
|
||||
self.bn = nn.BatchNorm(
|
||||
num_channels=out_channels,
|
||||
act=None,
|
||||
param_attr=ParamAttr(name=name + "_bn_scale"),
|
||||
bias_attr=ParamAttr(name=name + "_bn_offset"),
|
||||
moving_mean_name=name + "_bn_mean",
|
||||
moving_variance_name=name + "_bn_variance")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
if self.if_act:
|
||||
if self.act == "relu":
|
||||
x = F.relu(x)
|
||||
elif self.act == "hardswish":
|
||||
x = F.hardswish(x)
|
||||
else:
|
||||
print("The activation function({}) is selected incorrectly.".
|
||||
format(self.act))
|
||||
exit()
|
||||
return x
|
||||
|
||||
|
||||
class ResidualUnit(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
mid_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
use_se,
|
||||
act=None,
|
||||
name=''):
|
||||
super(ResidualUnit, self).__init__()
|
||||
self.if_shortcut = stride == 1 and in_channels == out_channels
|
||||
self.if_se = use_se
|
||||
|
||||
self.expand_conv = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=mid_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
if_act=True,
|
||||
act=act,
|
||||
name=name + "_expand")
|
||||
self.bottleneck_conv = ConvBNLayer(
|
||||
in_channels=mid_channels,
|
||||
out_channels=mid_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=int((kernel_size - 1) // 2),
|
||||
groups=mid_channels,
|
||||
if_act=True,
|
||||
act=act,
|
||||
name=name + "_depthwise")
|
||||
if self.if_se:
|
||||
self.mid_se = SEModule(mid_channels, name=name + "_se")
|
||||
self.linear_conv = ConvBNLayer(
|
||||
in_channels=mid_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
if_act=False,
|
||||
act=None,
|
||||
name=name + "_linear")
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self.expand_conv(inputs)
|
||||
x = self.bottleneck_conv(x)
|
||||
if self.if_se:
|
||||
x = self.mid_se(x)
|
||||
x = self.linear_conv(x)
|
||||
if self.if_shortcut:
|
||||
x = paddle.add(inputs, x)
|
||||
return x
|
||||
|
||||
|
||||
class SEModule(nn.Layer):
|
||||
def __init__(self, in_channels, reduction=4, name=""):
|
||||
super(SEModule, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2D(1)
|
||||
self.conv1 = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels // reduction,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
weight_attr=ParamAttr(name=name + "_1_weights"),
|
||||
bias_attr=ParamAttr(name=name + "_1_offset"))
|
||||
self.conv2 = nn.Conv2D(
|
||||
in_channels=in_channels // reduction,
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
weight_attr=ParamAttr(name + "_2_weights"),
|
||||
bias_attr=ParamAttr(name=name + "_2_offset"))
|
||||
|
||||
def forward(self, inputs):
|
||||
outputs = self.avg_pool(inputs)
|
||||
outputs = self.conv1(outputs)
|
||||
outputs = F.relu(outputs)
|
||||
outputs = self.conv2(outputs)
|
||||
outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
|
||||
return inputs * outputs
|
|
@ -0,0 +1,280 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
__all__ = ["ResNet"]
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
is_vd_mode=False,
|
||||
act=None,
|
||||
name=None, ):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
|
||||
self.is_vd_mode = is_vd_mode
|
||||
self._pool2d_avg = nn.AvgPool2D(
|
||||
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||
self._conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=False)
|
||||
if name == "conv1":
|
||||
bn_name = "bn_" + name
|
||||
else:
|
||||
bn_name = "bn" + name[3:]
|
||||
self._batch_norm = nn.BatchNorm(
|
||||
out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name=bn_name + '_scale'),
|
||||
bias_attr=ParamAttr(bn_name + '_offset'),
|
||||
moving_mean_name=bn_name + '_mean',
|
||||
moving_variance_name=bn_name + '_variance')
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.is_vd_mode:
|
||||
inputs = self._pool2d_avg(inputs)
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
return y
|
||||
|
||||
|
||||
class BottleneckBlock(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
shortcut=True,
|
||||
if_first=False,
|
||||
name=None):
|
||||
super(BottleneckBlock, self).__init__()
|
||||
|
||||
self.conv0 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
act='relu',
|
||||
name=name + "_branch2a")
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
act='relu',
|
||||
name=name + "_branch2b")
|
||||
self.conv2 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels * 4,
|
||||
kernel_size=1,
|
||||
act=None,
|
||||
name=name + "_branch2c")
|
||||
|
||||
if not shortcut:
|
||||
self.short = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels * 4,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
is_vd_mode=False if if_first else True,
|
||||
name=name + "_branch1")
|
||||
|
||||
self.shortcut = shortcut
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv0(inputs)
|
||||
conv1 = self.conv1(y)
|
||||
conv2 = self.conv2(conv1)
|
||||
|
||||
if self.shortcut:
|
||||
short = inputs
|
||||
else:
|
||||
short = self.short(inputs)
|
||||
y = paddle.add(x=short, y=conv2)
|
||||
y = F.relu(y)
|
||||
return y
|
||||
|
||||
|
||||
class BasicBlock(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
shortcut=True,
|
||||
if_first=False,
|
||||
name=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.stride = stride
|
||||
self.conv0 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
act='relu',
|
||||
name=name + "_branch2a")
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
act=None,
|
||||
name=name + "_branch2b")
|
||||
|
||||
if not shortcut:
|
||||
self.short = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
is_vd_mode=False if if_first else True,
|
||||
name=name + "_branch1")
|
||||
|
||||
self.shortcut = shortcut
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv0(inputs)
|
||||
conv1 = self.conv1(y)
|
||||
|
||||
if self.shortcut:
|
||||
short = inputs
|
||||
else:
|
||||
short = self.short(inputs)
|
||||
y = paddle.add(x=short, y=conv1)
|
||||
y = F.relu(y)
|
||||
return y
|
||||
|
||||
|
||||
class ResNet(nn.Layer):
|
||||
def __init__(self, in_channels=3, layers=50, **kwargs):
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
self.layers = layers
|
||||
supported_layers = [18, 34, 50, 101, 152, 200]
|
||||
assert layers in supported_layers, \
|
||||
"supported layers are {} but input layer is {}".format(
|
||||
supported_layers, layers)
|
||||
|
||||
if layers == 18:
|
||||
depth = [2, 2, 2, 2]
|
||||
elif layers == 34 or layers == 50:
|
||||
depth = [3, 4, 6, 3]
|
||||
elif layers == 101:
|
||||
depth = [3, 4, 23, 3]
|
||||
elif layers == 152:
|
||||
depth = [3, 8, 36, 3]
|
||||
elif layers == 200:
|
||||
depth = [3, 12, 48, 3]
|
||||
num_channels = [64, 256, 512,
|
||||
1024] if layers >= 50 else [64, 64, 128, 256]
|
||||
num_filters = [64, 128, 256, 512]
|
||||
|
||||
self.conv1_1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=32,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
act='relu',
|
||||
name="conv1_1")
|
||||
self.conv1_2 = ConvBNLayer(
|
||||
in_channels=32,
|
||||
out_channels=32,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name="conv1_2")
|
||||
self.conv1_3 = ConvBNLayer(
|
||||
in_channels=32,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name="conv1_3")
|
||||
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.stages = []
|
||||
self.out_channels = []
|
||||
if layers >= 50:
|
||||
for block in range(len(depth)):
|
||||
block_list = []
|
||||
shortcut = False
|
||||
for i in range(depth[block]):
|
||||
if layers in [101, 152] and block == 2:
|
||||
if i == 0:
|
||||
conv_name = "res" + str(block + 2) + "a"
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + "b" + str(i)
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
bottleneck_block = self.add_sublayer(
|
||||
'bb_%d_%d' % (block, i),
|
||||
BottleneckBlock(
|
||||
in_channels=num_channels[block]
|
||||
if i == 0 else num_filters[block] * 4,
|
||||
out_channels=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
if_first=block == i == 0,
|
||||
name=conv_name))
|
||||
shortcut = True
|
||||
block_list.append(bottleneck_block)
|
||||
self.out_channels.append(num_filters[block] * 4)
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
else:
|
||||
for block in range(len(depth)):
|
||||
block_list = []
|
||||
shortcut = False
|
||||
for i in range(depth[block]):
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
basic_block = self.add_sublayer(
|
||||
'bb_%d_%d' % (block, i),
|
||||
BasicBlock(
|
||||
in_channels=num_channels[block]
|
||||
if i == 0 else num_filters[block],
|
||||
out_channels=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
if_first=block == i == 0,
|
||||
name=conv_name))
|
||||
shortcut = True
|
||||
block_list.append(basic_block)
|
||||
self.out_channels.append(num_filters[block])
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv1_1(inputs)
|
||||
y = self.conv1_2(y)
|
||||
y = self.conv1_3(y)
|
||||
y = self.pool2d_max(y)
|
||||
out = []
|
||||
for block in self.stages:
|
||||
y = block(y)
|
||||
out.append(y)
|
||||
return out
|
|
@ -31,8 +31,10 @@ def build_head(config):
|
|||
from .cls_head import ClsHead
|
||||
support_dict = [
|
||||
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
||||
'SRNHead', 'PGHead']
|
||||
'SRNHead', 'PGHead', 'TableAttentionHead']
|
||||
|
||||
#table head
|
||||
from .table_att_head import TableAttentionHead
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('head only support {}'.format(
|
||||
|
|
|
@ -43,7 +43,7 @@ class ClsHead(nn.Layer):
|
|||
initializer=nn.initializer.Uniform(-stdv, stdv)),
|
||||
bias_attr=ParamAttr(name="fc_0.b_0"), )
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, targets=None):
|
||||
x = self.pool(x)
|
||||
x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
|
||||
x = self.fc(x)
|
||||
|
|
|
@ -23,10 +23,10 @@ import paddle.nn.functional as F
|
|||
from paddle import ParamAttr
|
||||
|
||||
|
||||
def get_bias_attr(k, name):
|
||||
def get_bias_attr(k):
|
||||
stdv = 1.0 / math.sqrt(k * 1.0)
|
||||
initializer = paddle.nn.initializer.Uniform(-stdv, stdv)
|
||||
bias_attr = ParamAttr(initializer=initializer, name=name + "_b_attr")
|
||||
bias_attr = ParamAttr(initializer=initializer)
|
||||
return bias_attr
|
||||
|
||||
|
||||
|
@ -38,18 +38,14 @@ class Head(nn.Layer):
|
|||
out_channels=in_channels // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(name=name_list[0] + '.w_0'),
|
||||
weight_attr=ParamAttr(),
|
||||
bias_attr=False)
|
||||
self.conv_bn1 = nn.BatchNorm(
|
||||
num_channels=in_channels // 4,
|
||||
param_attr=ParamAttr(
|
||||
name=name_list[1] + '.w_0',
|
||||
initializer=paddle.nn.initializer.Constant(value=1.0)),
|
||||
bias_attr=ParamAttr(
|
||||
name=name_list[1] + '.b_0',
|
||||
initializer=paddle.nn.initializer.Constant(value=1e-4)),
|
||||
moving_mean_name=name_list[1] + '.w_1',
|
||||
moving_variance_name=name_list[1] + '.w_2',
|
||||
act='relu')
|
||||
self.conv2 = nn.Conv2DTranspose(
|
||||
in_channels=in_channels // 4,
|
||||
|
@ -57,19 +53,14 @@ class Head(nn.Layer):
|
|||
kernel_size=2,
|
||||
stride=2,
|
||||
weight_attr=ParamAttr(
|
||||
name=name_list[2] + '.w_0',
|
||||
initializer=paddle.nn.initializer.KaimingUniform()),
|
||||
bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv2"))
|
||||
bias_attr=get_bias_attr(in_channels // 4))
|
||||
self.conv_bn2 = nn.BatchNorm(
|
||||
num_channels=in_channels // 4,
|
||||
param_attr=ParamAttr(
|
||||
name=name_list[3] + '.w_0',
|
||||
initializer=paddle.nn.initializer.Constant(value=1.0)),
|
||||
bias_attr=ParamAttr(
|
||||
name=name_list[3] + '.b_0',
|
||||
initializer=paddle.nn.initializer.Constant(value=1e-4)),
|
||||
moving_mean_name=name_list[3] + '.w_1',
|
||||
moving_variance_name=name_list[3] + '.w_2',
|
||||
act="relu")
|
||||
self.conv3 = nn.Conv2DTranspose(
|
||||
in_channels=in_channels // 4,
|
||||
|
@ -77,10 +68,8 @@ class Head(nn.Layer):
|
|||
kernel_size=2,
|
||||
stride=2,
|
||||
weight_attr=ParamAttr(
|
||||
name=name_list[4] + '.w_0',
|
||||
initializer=paddle.nn.initializer.KaimingUniform()),
|
||||
bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv3"),
|
||||
)
|
||||
bias_attr=get_bias_attr(in_channels // 4), )
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
|
@ -117,7 +106,7 @@ class DBHead(nn.Layer):
|
|||
def step_function(self, x, y):
|
||||
return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, targets=None):
|
||||
shrink_maps = self.binarize(x)
|
||||
if not self.training:
|
||||
return {'maps': shrink_maps}
|
||||
|
|
|
@ -109,7 +109,7 @@ class EASTHead(nn.Layer):
|
|||
act=None,
|
||||
name="f_geo")
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, targets=None):
|
||||
f_det = self.det_conv1(x)
|
||||
f_det = self.det_conv2(f_det)
|
||||
f_score = self.score_conv(f_det)
|
||||
|
|
|
@ -116,7 +116,7 @@ class SASTHead(nn.Layer):
|
|||
self.head1 = SAST_Header1(in_channels)
|
||||
self.head2 = SAST_Header2(in_channels)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, targets=None):
|
||||
f_score, f_border = self.head1(x)
|
||||
f_tvo, f_tco = self.head2(x)
|
||||
|
||||
|
|
|
@ -220,7 +220,7 @@ class PGHead(nn.Layer):
|
|||
weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
|
||||
bias_attr=False)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, targets=None):
|
||||
f_score = self.conv_f_score1(x)
|
||||
f_score = self.conv_f_score2(f_score)
|
||||
f_score = self.conv_f_score3(f_score)
|
||||
|
|
|
@ -23,32 +23,57 @@ from paddle import ParamAttr, nn
|
|||
from paddle.nn import functional as F
|
||||
|
||||
|
||||
def get_para_bias_attr(l2_decay, k, name):
|
||||
def get_para_bias_attr(l2_decay, k):
|
||||
regularizer = paddle.regularizer.L2Decay(l2_decay)
|
||||
stdv = 1.0 / math.sqrt(k * 1.0)
|
||||
initializer = nn.initializer.Uniform(-stdv, stdv)
|
||||
weight_attr = ParamAttr(
|
||||
regularizer=regularizer, initializer=initializer, name=name + "_w_attr")
|
||||
bias_attr = ParamAttr(
|
||||
regularizer=regularizer, initializer=initializer, name=name + "_b_attr")
|
||||
weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
|
||||
bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
|
||||
return [weight_attr, bias_attr]
|
||||
|
||||
|
||||
class CTCHead(nn.Layer):
|
||||
def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
fc_decay=0.0004,
|
||||
mid_channels=None,
|
||||
**kwargs):
|
||||
super(CTCHead, self).__init__()
|
||||
weight_attr, bias_attr = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=in_channels, name='ctc_fc')
|
||||
self.fc = nn.Linear(
|
||||
in_channels,
|
||||
out_channels,
|
||||
weight_attr=weight_attr,
|
||||
bias_attr=bias_attr,
|
||||
name='ctc_fc')
|
||||
self.out_channels = out_channels
|
||||
if mid_channels is None:
|
||||
weight_attr, bias_attr = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=in_channels)
|
||||
self.fc = nn.Linear(
|
||||
in_channels,
|
||||
out_channels,
|
||||
weight_attr=weight_attr,
|
||||
bias_attr=bias_attr)
|
||||
else:
|
||||
weight_attr1, bias_attr1 = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=in_channels)
|
||||
self.fc1 = nn.Linear(
|
||||
in_channels,
|
||||
mid_channels,
|
||||
weight_attr=weight_attr1,
|
||||
bias_attr=bias_attr1)
|
||||
|
||||
def forward(self, x, labels=None):
|
||||
predicts = self.fc(x)
|
||||
weight_attr2, bias_attr2 = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=mid_channels)
|
||||
self.fc2 = nn.Linear(
|
||||
mid_channels,
|
||||
out_channels,
|
||||
weight_attr=weight_attr2,
|
||||
bias_attr=bias_attr2)
|
||||
self.out_channels = out_channels
|
||||
self.mid_channels = mid_channels
|
||||
|
||||
def forward(self, x, targets=None):
|
||||
if self.mid_channels is None:
|
||||
predicts = self.fc(x)
|
||||
else:
|
||||
predicts = self.fc1(x)
|
||||
predicts = self.fc2(predicts)
|
||||
|
||||
if not self.training:
|
||||
predicts = F.softmax(predicts, axis=2)
|
||||
return predicts
|
||||
|
|
|
@ -250,7 +250,8 @@ class SRNHead(nn.Layer):
|
|||
|
||||
self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
|
||||
|
||||
def forward(self, inputs, others):
|
||||
def forward(self, inputs, targets=None):
|
||||
others = targets[-4:]
|
||||
encoder_word_pos = others[0]
|
||||
gsrm_word_pos = others[1]
|
||||
gsrm_slf_attn_bias1 = others[2]
|
||||
|
|
|
@ -0,0 +1,238 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TableAttentionHead(nn.Layer):
|
||||
def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
|
||||
super(TableAttentionHead, self).__init__()
|
||||
self.input_size = in_channels[-1]
|
||||
self.hidden_size = hidden_size
|
||||
self.elem_num = 30
|
||||
self.max_text_length = 100
|
||||
self.max_elem_length = 500
|
||||
self.max_cell_num = 500
|
||||
|
||||
self.structure_attention_cell = AttentionGRUCell(
|
||||
self.input_size, hidden_size, self.elem_num, use_gru=False)
|
||||
self.structure_generator = nn.Linear(hidden_size, self.elem_num)
|
||||
self.loc_type = loc_type
|
||||
self.in_max_len = in_max_len
|
||||
|
||||
if self.loc_type == 1:
|
||||
self.loc_generator = nn.Linear(hidden_size, 4)
|
||||
else:
|
||||
if self.in_max_len == 640:
|
||||
self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
|
||||
elif self.in_max_len == 800:
|
||||
self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
|
||||
else:
|
||||
self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
|
||||
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
|
||||
|
||||
def _char_to_onehot(self, input_char, onehot_dim):
|
||||
input_ont_hot = F.one_hot(input_char, onehot_dim)
|
||||
return input_ont_hot
|
||||
|
||||
def forward(self, inputs, targets=None):
|
||||
# if and else branch are both needed when you want to assign a variable
|
||||
# if you modify the var in just one branch, then the modification will not work.
|
||||
fea = inputs[-1]
|
||||
if len(fea.shape) == 3:
|
||||
pass
|
||||
else:
|
||||
last_shape = int(np.prod(fea.shape[2:])) # gry added
|
||||
fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
|
||||
fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
|
||||
batch_size = fea.shape[0]
|
||||
|
||||
hidden = paddle.zeros((batch_size, self.hidden_size))
|
||||
output_hiddens = []
|
||||
if self.training and targets is not None:
|
||||
structure = targets[0]
|
||||
for i in range(self.max_elem_length+1):
|
||||
elem_onehots = self._char_to_onehot(
|
||||
structure[:, i], onehot_dim=self.elem_num)
|
||||
(outputs, hidden), alpha = self.structure_attention_cell(
|
||||
hidden, fea, elem_onehots)
|
||||
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
|
||||
output = paddle.concat(output_hiddens, axis=1)
|
||||
structure_probs = self.structure_generator(output)
|
||||
if self.loc_type == 1:
|
||||
loc_preds = self.loc_generator(output)
|
||||
loc_preds = F.sigmoid(loc_preds)
|
||||
else:
|
||||
loc_fea = fea.transpose([0, 2, 1])
|
||||
loc_fea = self.loc_fea_trans(loc_fea)
|
||||
loc_fea = loc_fea.transpose([0, 2, 1])
|
||||
loc_concat = paddle.concat([output, loc_fea], axis=2)
|
||||
loc_preds = self.loc_generator(loc_concat)
|
||||
loc_preds = F.sigmoid(loc_preds)
|
||||
else:
|
||||
temp_elem = paddle.zeros(shape=[batch_size], dtype="int32")
|
||||
structure_probs = None
|
||||
loc_preds = None
|
||||
elem_onehots = None
|
||||
outputs = None
|
||||
alpha = None
|
||||
max_elem_length = paddle.to_tensor(self.max_elem_length)
|
||||
i = 0
|
||||
while i < max_elem_length+1:
|
||||
elem_onehots = self._char_to_onehot(
|
||||
temp_elem, onehot_dim=self.elem_num)
|
||||
(outputs, hidden), alpha = self.structure_attention_cell(
|
||||
hidden, fea, elem_onehots)
|
||||
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
|
||||
structure_probs_step = self.structure_generator(outputs)
|
||||
temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
|
||||
i += 1
|
||||
|
||||
output = paddle.concat(output_hiddens, axis=1)
|
||||
structure_probs = self.structure_generator(output)
|
||||
structure_probs = F.softmax(structure_probs)
|
||||
if self.loc_type == 1:
|
||||
loc_preds = self.loc_generator(output)
|
||||
loc_preds = F.sigmoid(loc_preds)
|
||||
else:
|
||||
loc_fea = fea.transpose([0, 2, 1])
|
||||
loc_fea = self.loc_fea_trans(loc_fea)
|
||||
loc_fea = loc_fea.transpose([0, 2, 1])
|
||||
loc_concat = paddle.concat([output, loc_fea], axis=2)
|
||||
loc_preds = self.loc_generator(loc_concat)
|
||||
loc_preds = F.sigmoid(loc_preds)
|
||||
return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
|
||||
|
||||
|
||||
class AttentionGRUCell(nn.Layer):
|
||||
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
|
||||
super(AttentionGRUCell, self).__init__()
|
||||
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
|
||||
self.h2h = nn.Linear(hidden_size, hidden_size)
|
||||
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
|
||||
self.rnn = nn.GRUCell(
|
||||
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
def forward(self, prev_hidden, batch_H, char_onehots):
|
||||
batch_H_proj = self.i2h(batch_H)
|
||||
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
|
||||
res = paddle.add(batch_H_proj, prev_hidden_proj)
|
||||
res = paddle.tanh(res)
|
||||
e = self.score(res)
|
||||
alpha = F.softmax(e, axis=1)
|
||||
alpha = paddle.transpose(alpha, [0, 2, 1])
|
||||
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
|
||||
concat_context = paddle.concat([context, char_onehots], 1)
|
||||
cur_hidden = self.rnn(concat_context, prev_hidden)
|
||||
return cur_hidden, alpha
|
||||
|
||||
|
||||
class AttentionLSTM(nn.Layer):
|
||||
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
|
||||
super(AttentionLSTM, self).__init__()
|
||||
self.input_size = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_classes = out_channels
|
||||
|
||||
self.attention_cell = AttentionLSTMCell(
|
||||
in_channels, hidden_size, out_channels, use_gru=False)
|
||||
self.generator = nn.Linear(hidden_size, out_channels)
|
||||
|
||||
def _char_to_onehot(self, input_char, onehot_dim):
|
||||
input_ont_hot = F.one_hot(input_char, onehot_dim)
|
||||
return input_ont_hot
|
||||
|
||||
def forward(self, inputs, targets=None, batch_max_length=25):
|
||||
batch_size = inputs.shape[0]
|
||||
num_steps = batch_max_length
|
||||
|
||||
hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
|
||||
(batch_size, self.hidden_size)))
|
||||
output_hiddens = []
|
||||
|
||||
if targets is not None:
|
||||
for i in range(num_steps):
|
||||
# one-hot vectors for a i-th char
|
||||
char_onehots = self._char_to_onehot(
|
||||
targets[:, i], onehot_dim=self.num_classes)
|
||||
hidden, alpha = self.attention_cell(hidden, inputs,
|
||||
char_onehots)
|
||||
|
||||
hidden = (hidden[1][0], hidden[1][1])
|
||||
output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
|
||||
output = paddle.concat(output_hiddens, axis=1)
|
||||
probs = self.generator(output)
|
||||
|
||||
else:
|
||||
targets = paddle.zeros(shape=[batch_size], dtype="int32")
|
||||
probs = None
|
||||
|
||||
for i in range(num_steps):
|
||||
char_onehots = self._char_to_onehot(
|
||||
targets, onehot_dim=self.num_classes)
|
||||
hidden, alpha = self.attention_cell(hidden, inputs,
|
||||
char_onehots)
|
||||
probs_step = self.generator(hidden[0])
|
||||
hidden = (hidden[1][0], hidden[1][1])
|
||||
if probs is None:
|
||||
probs = paddle.unsqueeze(probs_step, axis=1)
|
||||
else:
|
||||
probs = paddle.concat(
|
||||
[probs, paddle.unsqueeze(
|
||||
probs_step, axis=1)], axis=1)
|
||||
|
||||
next_input = probs_step.argmax(axis=1)
|
||||
|
||||
targets = next_input
|
||||
|
||||
return probs
|
||||
|
||||
|
||||
class AttentionLSTMCell(nn.Layer):
|
||||
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
|
||||
super(AttentionLSTMCell, self).__init__()
|
||||
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
|
||||
self.h2h = nn.Linear(hidden_size, hidden_size)
|
||||
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
|
||||
if not use_gru:
|
||||
self.rnn = nn.LSTMCell(
|
||||
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||
else:
|
||||
self.rnn = nn.GRUCell(
|
||||
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
def forward(self, prev_hidden, batch_H, char_onehots):
|
||||
batch_H_proj = self.i2h(batch_H)
|
||||
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
|
||||
res = paddle.add(batch_H_proj, prev_hidden_proj)
|
||||
res = paddle.tanh(res)
|
||||
e = self.score(res)
|
||||
|
||||
alpha = F.softmax(e, axis=1)
|
||||
alpha = paddle.transpose(alpha, [0, 2, 1])
|
||||
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
|
||||
concat_context = paddle.concat([context, char_onehots], 1)
|
||||
cur_hidden = self.rnn(concat_context, prev_hidden)
|
||||
|
||||
return cur_hidden, alpha
|
|
@ -21,7 +21,8 @@ def build_neck(config):
|
|||
from .sast_fpn import SASTFPN
|
||||
from .rnn import SequenceEncoder
|
||||
from .pg_fpn import PGFPN
|
||||
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
|
||||
from .table_fpn import TableFPN
|
||||
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('neck only support {}'.format(
|
||||
|
|
|
@ -32,61 +32,53 @@ class DBFPN(nn.Layer):
|
|||
in_channels=in_channels[0],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_51.w_0', initializer=weight_attr),
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.in3_conv = nn.Conv2D(
|
||||
in_channels=in_channels[1],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_50.w_0', initializer=weight_attr),
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.in4_conv = nn.Conv2D(
|
||||
in_channels=in_channels[2],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_49.w_0', initializer=weight_attr),
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.in5_conv = nn.Conv2D(
|
||||
in_channels=in_channels[3],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_48.w_0', initializer=weight_attr),
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p5_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_52.w_0', initializer=weight_attr),
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p4_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_53.w_0', initializer=weight_attr),
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p3_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_54.w_0', initializer=weight_attr),
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p2_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_55.w_0', initializer=weight_attr),
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
|
||||
|
||||
class TableFPN(nn.Layer):
|
||||
def __init__(self, in_channels, out_channels, **kwargs):
|
||||
super(TableFPN, self).__init__()
|
||||
self.out_channels = 512
|
||||
weight_attr = paddle.nn.initializer.KaimingUniform()
|
||||
self.in2_conv = nn.Conv2D(
|
||||
in_channels=in_channels[0],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.in3_conv = nn.Conv2D(
|
||||
in_channels=in_channels[1],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
stride = 1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.in4_conv = nn.Conv2D(
|
||||
in_channels=in_channels[2],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.in5_conv = nn.Conv2D(
|
||||
in_channels=in_channels[3],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p5_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p4_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p3_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p2_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.fuse_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels * 4,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False)
|
||||
|
||||
def forward(self, x):
|
||||
c2, c3, c4, c5 = x
|
||||
|
||||
in5 = self.in5_conv(c5)
|
||||
in4 = self.in4_conv(c4)
|
||||
in3 = self.in3_conv(c3)
|
||||
in2 = self.in2_conv(c2)
|
||||
|
||||
out4 = in4 + F.upsample(
|
||||
in5, size=in4.shape[2:4], mode="nearest", align_mode=1) # 1/16
|
||||
out3 = in3 + F.upsample(
|
||||
out4, size=in3.shape[2:4], mode="nearest", align_mode=1) # 1/8
|
||||
out2 = in2 + F.upsample(
|
||||
out3, size=in2.shape[2:4], mode="nearest", align_mode=1) # 1/4
|
||||
|
||||
p4 = F.upsample(out4, size=in5.shape[2:4], mode="nearest", align_mode=1)
|
||||
p3 = F.upsample(out3, size=in5.shape[2:4], mode="nearest", align_mode=1)
|
||||
p2 = F.upsample(out2, size=in5.shape[2:4], mode="nearest", align_mode=1)
|
||||
fuse = paddle.concat([in5, p4, p3, p2], axis=1)
|
||||
fuse_conv = self.fuse_conv(fuse) * 0.005
|
||||
return [c5 + fuse_conv]
|
|
@ -230,15 +230,8 @@ class GridGenerator(nn.Layer):
|
|||
def build_inv_delta_C_paddle(self, C):
|
||||
""" Return inv_delta_C which is needed to calculate T """
|
||||
F = self.F
|
||||
hat_C = paddle.zeros((F, F), dtype='float64') # F x F
|
||||
for i in range(0, F):
|
||||
for j in range(i, F):
|
||||
if i == j:
|
||||
hat_C[i, j] = 1
|
||||
else:
|
||||
r = paddle.norm(C[i] - C[j])
|
||||
hat_C[i, j] = r
|
||||
hat_C[j, i] = r
|
||||
hat_eye = paddle.eye(F, dtype='float64') # F x F
|
||||
hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
|
||||
hat_C = (hat_C**2) * paddle.log(hat_C)
|
||||
delta_C = paddle.concat( # F+3 x F+3
|
||||
[
|
||||
|
|
|
@ -21,18 +21,21 @@ import copy
|
|||
|
||||
__all__ = ['build_post_process']
|
||||
|
||||
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
|
||||
from .east_postprocess import EASTPostProcess
|
||||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
|
||||
TableLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
|
||||
|
||||
def build_post_process(config, global_config=None):
|
||||
from .db_postprocess import DBPostProcess
|
||||
from .east_postprocess import EASTPostProcess
|
||||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
|
||||
support_dict = [
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess'
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode',
|
||||
'DistillationDBPostProcess'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -187,3 +187,29 @@ class DBPostProcess(object):
|
|||
|
||||
boxes_batch.append({'points': boxes})
|
||||
return boxes_batch
|
||||
|
||||
|
||||
class DistillationDBPostProcess(object):
|
||||
def __init__(self, model_name=["student"],
|
||||
key=None,
|
||||
thresh=0.3,
|
||||
box_thresh=0.6,
|
||||
max_candidates=1000,
|
||||
unclip_ratio=1.5,
|
||||
use_dilation=False,
|
||||
score_mode="fast",
|
||||
**kwargs):
|
||||
self.model_name = model_name
|
||||
self.key = key
|
||||
self.post_process = DBPostProcess(thresh=thresh,
|
||||
box_thresh=box_thresh,
|
||||
max_candidates=max_candidates,
|
||||
unclip_ratio=unclip_ratio,
|
||||
use_dilation=use_dilation,
|
||||
score_mode=score_mode)
|
||||
|
||||
def __call__(self, predicts, shape_list):
|
||||
results = {}
|
||||
for k in self.model_name:
|
||||
results[k] = self.post_process(predicts[k], shape_list=shape_list)
|
||||
return results
|
||||
|
|
|
@ -44,16 +44,16 @@ class BaseRecLabelDecode(object):
|
|||
self.character_str = string.printable[:-6]
|
||||
dict_character = list(self.character_str)
|
||||
elif character_type in support_character_type:
|
||||
self.character_str = ""
|
||||
self.character_str = []
|
||||
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
|
||||
character_type)
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||
self.character_str += line
|
||||
self.character_str.append(line)
|
||||
if use_space_char:
|
||||
self.character_str += " "
|
||||
self.character_str.append(" ")
|
||||
dict_character = list(self.character_str)
|
||||
|
||||
else:
|
||||
|
@ -125,6 +125,37 @@ class CTCLabelDecode(BaseRecLabelDecode):
|
|||
return dict_character
|
||||
|
||||
|
||||
class DistillationCTCLabelDecode(CTCLabelDecode):
|
||||
"""
|
||||
Convert
|
||||
Convert between text-label and text-index
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
character_dict_path=None,
|
||||
character_type='ch',
|
||||
use_space_char=False,
|
||||
model_name=["student"],
|
||||
key=None,
|
||||
**kwargs):
|
||||
super(DistillationCTCLabelDecode, self).__init__(
|
||||
character_dict_path, character_type, use_space_char)
|
||||
if not isinstance(model_name, list):
|
||||
model_name = [model_name]
|
||||
self.model_name = model_name
|
||||
|
||||
self.key = key
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
output = dict()
|
||||
for name in self.model_name:
|
||||
pred = preds[name]
|
||||
if self.key is not None:
|
||||
pred = pred[self.key]
|
||||
output[name] = super().__call__(pred, label=label, *args, **kwargs)
|
||||
return output
|
||||
|
||||
|
||||
class AttnLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
@ -288,3 +319,138 @@ class SRNLabelDecode(BaseRecLabelDecode):
|
|||
assert False, "unsupport type %s in get_beg_end_flag_idx" \
|
||||
% beg_or_end
|
||||
return idx
|
||||
|
||||
|
||||
class TableLabelDecode(object):
|
||||
""" """
|
||||
|
||||
def __init__(self,
|
||||
character_dict_path,
|
||||
**kwargs):
|
||||
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
|
||||
list_character = self.add_special_char(list_character)
|
||||
list_elem = self.add_special_char(list_elem)
|
||||
self.dict_character = {}
|
||||
self.dict_idx_character = {}
|
||||
for i, char in enumerate(list_character):
|
||||
self.dict_idx_character[i] = char
|
||||
self.dict_character[char] = i
|
||||
self.dict_elem = {}
|
||||
self.dict_idx_elem = {}
|
||||
for i, elem in enumerate(list_elem):
|
||||
self.dict_idx_elem[i] = elem
|
||||
self.dict_elem[elem] = i
|
||||
|
||||
def load_char_elem_dict(self, character_dict_path):
|
||||
list_character = []
|
||||
list_elem = []
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split("\t")
|
||||
character_num = int(substr[0])
|
||||
elem_num = int(substr[1])
|
||||
for cno in range(1, 1 + character_num):
|
||||
character = lines[cno].decode('utf-8').strip("\n").strip("\r\n")
|
||||
list_character.append(character)
|
||||
for eno in range(1 + character_num, 1 + character_num + elem_num):
|
||||
elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n")
|
||||
list_elem.append(elem)
|
||||
return list_character, list_elem
|
||||
|
||||
def add_special_char(self, list_character):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
list_character = [self.beg_str] + list_character + [self.end_str]
|
||||
return list_character
|
||||
|
||||
def __call__(self, preds):
|
||||
structure_probs = preds['structure_probs']
|
||||
loc_preds = preds['loc_preds']
|
||||
if isinstance(structure_probs,paddle.Tensor):
|
||||
structure_probs = structure_probs.numpy()
|
||||
if isinstance(loc_preds,paddle.Tensor):
|
||||
loc_preds = loc_preds.numpy()
|
||||
structure_idx = structure_probs.argmax(axis=2)
|
||||
structure_probs = structure_probs.max(axis=2)
|
||||
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
|
||||
structure_probs, 'elem')
|
||||
res_html_code_list = []
|
||||
res_loc_list = []
|
||||
batch_num = len(structure_str)
|
||||
for bno in range(batch_num):
|
||||
res_loc = []
|
||||
for sno in range(len(structure_str[bno])):
|
||||
text = structure_str[bno][sno]
|
||||
if text in ['<td>', '<td']:
|
||||
pos = structure_pos[bno][sno]
|
||||
res_loc.append(loc_preds[bno, pos])
|
||||
res_html_code = ''.join(structure_str[bno])
|
||||
res_loc = np.array(res_loc)
|
||||
res_html_code_list.append(res_html_code)
|
||||
res_loc_list.append(res_loc)
|
||||
return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
|
||||
'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str}
|
||||
|
||||
def decode(self, text_index, structure_probs, char_or_elem):
|
||||
"""convert text-label into text-index.
|
||||
"""
|
||||
if char_or_elem == "char":
|
||||
current_dict = self.dict_idx_character
|
||||
else:
|
||||
current_dict = self.dict_idx_elem
|
||||
ignored_tokens = self.get_ignored_tokens('elem')
|
||||
beg_idx, end_idx = ignored_tokens
|
||||
|
||||
result_list = []
|
||||
result_pos_list = []
|
||||
result_score_list = []
|
||||
result_elem_idx_list = []
|
||||
batch_size = len(text_index)
|
||||
for batch_idx in range(batch_size):
|
||||
char_list = []
|
||||
elem_pos_list = []
|
||||
elem_idx_list = []
|
||||
score_list = []
|
||||
for idx in range(len(text_index[batch_idx])):
|
||||
tmp_elem_idx = int(text_index[batch_idx][idx])
|
||||
if idx > 0 and tmp_elem_idx == end_idx:
|
||||
break
|
||||
if tmp_elem_idx in ignored_tokens:
|
||||
continue
|
||||
|
||||
char_list.append(current_dict[tmp_elem_idx])
|
||||
elem_pos_list.append(idx)
|
||||
score_list.append(structure_probs[batch_idx, idx])
|
||||
elem_idx_list.append(tmp_elem_idx)
|
||||
result_list.append(char_list)
|
||||
result_pos_list.append(elem_pos_list)
|
||||
result_score_list.append(score_list)
|
||||
result_elem_idx_list.append(elem_idx_list)
|
||||
return result_list, result_pos_list, result_score_list, result_elem_idx_list
|
||||
|
||||
def get_ignored_tokens(self, char_or_elem):
|
||||
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
|
||||
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
|
||||
return [beg_idx, end_idx]
|
||||
|
||||
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
|
||||
if char_or_elem == "char":
|
||||
if beg_or_end == "beg":
|
||||
idx = self.dict_character[self.beg_str]
|
||||
elif beg_or_end == "end":
|
||||
idx = self.dict_character[self.end_str]
|
||||
else:
|
||||
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
|
||||
% beg_or_end
|
||||
elif char_or_elem == "elem":
|
||||
if beg_or_end == "beg":
|
||||
idx = self.dict_elem[self.beg_str]
|
||||
elif beg_or_end == "end":
|
||||
idx = self.dict_elem[self.end_str]
|
||||
else:
|
||||
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
|
||||
% beg_or_end
|
||||
else:
|
||||
assert False, "Unsupport type %s in char_or_elem" \
|
||||
% char_or_elem
|
||||
return idx
|
||||
|
|
|
@ -0,0 +1,277 @@
|
|||
←
|
||||
</overline>
|
||||
☆
|
||||
─
|
||||
α
|
||||
|
||||
|
||||
⋅
|
||||
$
|
||||
ω
|
||||
ψ
|
||||
χ
|
||||
(
|
||||
υ
|
||||
≥
|
||||
σ
|
||||
,
|
||||
ρ
|
||||
ε
|
||||
0
|
||||
■
|
||||
4
|
||||
8
|
||||
✗
|
||||
b
|
||||
<
|
||||
✓
|
||||
Ψ
|
||||
Ω
|
||||
€
|
||||
D
|
||||
3
|
||||
Π
|
||||
H
|
||||
║
|
||||
</strike>
|
||||
L
|
||||
Φ
|
||||
Χ
|
||||
θ
|
||||
P
|
||||
κ
|
||||
λ
|
||||
μ
|
||||
T
|
||||
ξ
|
||||
X
|
||||
β
|
||||
γ
|
||||
δ
|
||||
\
|
||||
ζ
|
||||
η
|
||||
`
|
||||
d
|
||||
<strike>
|
||||
h
|
||||
f
|
||||
l
|
||||
Θ
|
||||
p
|
||||
√
|
||||
t
|
||||
</sub>
|
||||
x
|
||||
Β
|
||||
Γ
|
||||
Δ
|
||||
|
|
||||
ǂ
|
||||
ɛ
|
||||
j
|
||||
̧
|
||||
➢
|
||||
|
||||
̌
|
||||
′
|
||||
«
|
||||
△
|
||||
▲
|
||||
#
|
||||
</b>
|
||||
'
|
||||
Ι
|
||||
+
|
||||
¶
|
||||
/
|
||||
▼
|
||||
⇑
|
||||
□
|
||||
·
|
||||
7
|
||||
▪
|
||||
;
|
||||
?
|
||||
➔
|
||||
∩
|
||||
C
|
||||
÷
|
||||
G
|
||||
⇒
|
||||
K
|
||||
<sup>
|
||||
O
|
||||
S
|
||||
С
|
||||
W
|
||||
Α
|
||||
[
|
||||
○
|
||||
_
|
||||
●
|
||||
‡
|
||||
c
|
||||
z
|
||||
g
|
||||
<i>
|
||||
o
|
||||
<sub>
|
||||
〈
|
||||
〉
|
||||
s
|
||||
⩽
|
||||
w
|
||||
φ
|
||||
ʹ
|
||||
{
|
||||
»
|
||||
∣
|
||||
̆
|
||||
e
|
||||
ˆ
|
||||
∈
|
||||
τ
|
||||
◆
|
||||
ι
|
||||
∅
|
||||
∆
|
||||
∙
|
||||
∘
|
||||
Ø
|
||||
ß
|
||||
✔
|
||||
∞
|
||||
∑
|
||||
−
|
||||
×
|
||||
◊
|
||||
∗
|
||||
∖
|
||||
˃
|
||||
˂
|
||||
∫
|
||||
"
|
||||
i
|
||||
&
|
||||
π
|
||||
↔
|
||||
*
|
||||
∥
|
||||
æ
|
||||
∧
|
||||
.
|
||||
⁄
|
||||
ø
|
||||
Q
|
||||
∼
|
||||
6
|
||||
⁎
|
||||
:
|
||||
★
|
||||
>
|
||||
a
|
||||
B
|
||||
≈
|
||||
F
|
||||
J
|
||||
̄
|
||||
N
|
||||
♯
|
||||
R
|
||||
V
|
||||
<overline>
|
||||
―
|
||||
Z
|
||||
♣
|
||||
^
|
||||
¤
|
||||
¥
|
||||
§
|
||||
<underline>
|
||||
¢
|
||||
£
|
||||
≦
|
||||
|
||||
≤
|
||||
‖
|
||||
Λ
|
||||
©
|
||||
n
|
||||
↓
|
||||
→
|
||||
↑
|
||||
r
|
||||
°
|
||||
±
|
||||
v
|
||||
<b>
|
||||
♂
|
||||
k
|
||||
♀
|
||||
~
|
||||
ᅟ
|
||||
̇
|
||||
@
|
||||
”
|
||||
♦
|
||||
ł
|
||||
®
|
||||
⊕
|
||||
„
|
||||
!
|
||||
</sup>
|
||||
%
|
||||
⇓
|
||||
)
|
||||
-
|
||||
1
|
||||
5
|
||||
9
|
||||
=
|
||||
А
|
||||
A
|
||||
‰
|
||||
⋆
|
||||
Σ
|
||||
E
|
||||
◦
|
||||
I
|
||||
※
|
||||
M
|
||||
m
|
||||
̨
|
||||
⩾
|
||||
†
|
||||
</i>
|
||||
•
|
||||
U
|
||||
Y
|
||||
|
||||
]
|
||||
̸
|
||||
2
|
||||
‐
|
||||
–
|
||||
‒
|
||||
̂
|
||||
—
|
||||
̀
|
||||
́
|
||||
’
|
||||
‘
|
||||
⋮
|
||||
⋯
|
||||
̊
|
||||
“
|
||||
̈
|
||||
≧
|
||||
q
|
||||
u
|
||||
ı
|
||||
y
|
||||
</underline>
|
||||
|
||||
̃
|
||||
}
|
||||
ν
|
|
@ -1,16 +1,16 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import argparse
|
||||
import json
|
||||
|
@ -31,7 +31,9 @@ def gen_det_label(root_path, input_dir, out_label):
|
|||
for label_file in os.listdir(input_dir):
|
||||
img_path = root_path + label_file[3:-4] + ".jpg"
|
||||
label = []
|
||||
with open(os.path.join(input_dir, label_file), 'r') as f:
|
||||
with open(
|
||||
os.path.join(input_dir, label_file), 'r',
|
||||
encoding='utf-8-sig') as f:
|
||||
for line in f.readlines():
|
||||
tmp = line.strip("\n\r").replace("\xef\xbb\xbf",
|
||||
"").split(',')
|
||||
|
|
|
@ -22,7 +22,7 @@ logger_initialized = {}
|
|||
|
||||
|
||||
@functools.lru_cache()
|
||||
def get_logger(name='root', log_file=None, log_level=logging.INFO):
|
||||
def get_logger(name='root', log_file=None, log_level=logging.DEBUG):
|
||||
"""Initialize and get a logger by name.
|
||||
If the logger has not been initialized, this method will initialize the
|
||||
logger by adding one or two handlers, otherwise the initialized logger will
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tarfile
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
|
||||
def download_with_progressbar(url, save_path):
|
||||
logger = get_logger()
|
||||
response = requests.get(url, stream=True)
|
||||
total_size_in_bytes = int(response.headers.get('content-length', 0))
|
||||
block_size = 1024 # 1 Kibibyte
|
||||
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
|
||||
with open(save_path, 'wb') as file:
|
||||
for data in response.iter_content(block_size):
|
||||
progress_bar.update(len(data))
|
||||
file.write(data)
|
||||
progress_bar.close()
|
||||
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
|
||||
logger.error("Something went wrong while downloading models")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def maybe_download(model_storage_directory, url):
|
||||
# using custom model
|
||||
tar_file_name_list = [
|
||||
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
|
||||
]
|
||||
if not os.path.exists(
|
||||
os.path.join(model_storage_directory, 'inference.pdiparams')
|
||||
) or not os.path.exists(
|
||||
os.path.join(model_storage_directory, 'inference.pdmodel')):
|
||||
assert url.endswith('.tar'), 'Only supports tar compressed package'
|
||||
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
|
||||
print('download {} to {}'.format(url, tmp_path))
|
||||
os.makedirs(model_storage_directory, exist_ok=True)
|
||||
download_with_progressbar(url, tmp_path)
|
||||
with tarfile.open(tmp_path, 'r') as tarObj:
|
||||
for member in tarObj.getmembers():
|
||||
filename = None
|
||||
for tar_file_name in tar_file_name_list:
|
||||
if tar_file_name in member.name:
|
||||
filename = tar_file_name
|
||||
if filename is None:
|
||||
continue
|
||||
file = tarObj.extractfile(member)
|
||||
with open(
|
||||
os.path.join(model_storage_directory, filename),
|
||||
'wb') as f:
|
||||
f.write(file.read())
|
||||
os.remove(tmp_path)
|
||||
|
||||
|
||||
def is_link(s):
|
||||
return s is not None and s.startswith('http')
|
||||
|
||||
|
||||
def confirm_model_dir_url(model_dir, default_model_dir, default_url):
|
||||
url = default_url
|
||||
if model_dir is None or is_link(model_dir):
|
||||
if is_link(model_dir):
|
||||
url = model_dir
|
||||
file_name = url.split('/')[-1][:-4]
|
||||
model_dir = default_model_dir
|
||||
model_dir = os.path.join(model_dir, file_name)
|
||||
return model_dir, url
|
|
@ -23,7 +23,9 @@ import six
|
|||
|
||||
import paddle
|
||||
|
||||
__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
__all__ = ['init_model', 'save_model', 'load_dygraph_params']
|
||||
|
||||
|
||||
def _mkdir_if_not_exist(path, logger):
|
||||
|
@ -42,44 +44,11 @@ def _mkdir_if_not_exist(path, logger):
|
|||
raise OSError('Failed to mkdir {}'.format(path))
|
||||
|
||||
|
||||
def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
|
||||
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
|
||||
raise ValueError("Model pretrain path {} does not "
|
||||
"exists.".format(path))
|
||||
if load_static_weights:
|
||||
pre_state_dict = paddle.static.load_program_state(path)
|
||||
param_state_dict = {}
|
||||
model_dict = model.state_dict()
|
||||
for key in model_dict.keys():
|
||||
weight_name = model_dict[key].name
|
||||
weight_name = weight_name.replace('binarize', '').replace(
|
||||
'thresh', '') # for DB
|
||||
if weight_name in pre_state_dict.keys():
|
||||
# logger.info('Load weight: {}, shape: {}'.format(
|
||||
# weight_name, pre_state_dict[weight_name].shape))
|
||||
if 'encoder_rnn' in key:
|
||||
# delete axis which is 1
|
||||
pre_state_dict[weight_name] = pre_state_dict[
|
||||
weight_name].squeeze()
|
||||
# change axis
|
||||
if len(pre_state_dict[weight_name].shape) > 1:
|
||||
pre_state_dict[weight_name] = pre_state_dict[
|
||||
weight_name].transpose((1, 0))
|
||||
param_state_dict[key] = pre_state_dict[weight_name]
|
||||
else:
|
||||
param_state_dict[key] = model_dict[key]
|
||||
model.set_state_dict(param_state_dict)
|
||||
return
|
||||
|
||||
param_state_dict = paddle.load(path + '.pdparams')
|
||||
model.set_state_dict(param_state_dict)
|
||||
return
|
||||
|
||||
|
||||
def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
|
||||
def init_model(config, model, optimizer=None, lr_scheduler=None):
|
||||
"""
|
||||
load model from checkpoint or pretrained_model
|
||||
"""
|
||||
logger = get_logger()
|
||||
global_config = config['Global']
|
||||
checkpoints = global_config.get('checkpoints')
|
||||
pretrained_model = global_config.get('pretrained_model')
|
||||
|
@ -102,18 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
|
|||
best_model_dict = states_dict.get('best_model_dict', {})
|
||||
if 'epoch' in states_dict:
|
||||
best_model_dict['start_epoch'] = states_dict['epoch'] + 1
|
||||
|
||||
logger.info("resume from {}".format(checkpoints))
|
||||
elif pretrained_model:
|
||||
load_static_weights = global_config.get('load_static_weights', False)
|
||||
if not isinstance(pretrained_model, list):
|
||||
pretrained_model = [pretrained_model]
|
||||
if not isinstance(load_static_weights, list):
|
||||
load_static_weights = [load_static_weights] * len(pretrained_model)
|
||||
for idx, pretrained in enumerate(pretrained_model):
|
||||
load_static = load_static_weights[idx]
|
||||
load_dygraph_pretrain(
|
||||
model, logger, path=pretrained, load_static_weights=load_static)
|
||||
for pretrained in pretrained_model:
|
||||
if not (os.path.isdir(pretrained) or
|
||||
os.path.exists(pretrained + '.pdparams')):
|
||||
raise ValueError("Model pretrain path {} does not "
|
||||
"exists.".format(pretrained))
|
||||
param_state_dict = paddle.load(pretrained + '.pdparams')
|
||||
model.set_state_dict(param_state_dict)
|
||||
logger.info("load pretrained model from {}".format(
|
||||
pretrained_model))
|
||||
else:
|
||||
|
@ -121,6 +89,55 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
|
|||
return best_model_dict
|
||||
|
||||
|
||||
def load_dygraph_params(config, model, logger, optimizer):
|
||||
ckp = config['Global']['checkpoints']
|
||||
if ckp and os.path.exists(ckp + ".pdparams"):
|
||||
pre_best_model_dict = init_model(config, model, optimizer)
|
||||
return pre_best_model_dict
|
||||
else:
|
||||
pm = config['Global']['pretrained_model']
|
||||
if pm is None:
|
||||
return {}
|
||||
if not os.path.exists(pm) and not os.path.exists(pm + ".pdparams"):
|
||||
logger.info(f"The pretrained_model {pm} does not exists!")
|
||||
return {}
|
||||
pm = pm if pm.endswith('.pdparams') else pm + '.pdparams'
|
||||
params = paddle.load(pm)
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
||||
if list(state_dict[k1].shape) == list(params[k2].shape):
|
||||
new_state_dict[k1] = params[k2]
|
||||
else:
|
||||
logger.info(
|
||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
||||
)
|
||||
model.set_state_dict(new_state_dict)
|
||||
logger.info(f"loaded pretrained_model successful from {pm}")
|
||||
return {}
|
||||
|
||||
def load_pretrained_params(model, path):
|
||||
if path is None:
|
||||
return False
|
||||
if not os.path.exists(path) and not os.path.exists(path + ".pdparams"):
|
||||
print(f"The pretrained_model {path} does not exists!")
|
||||
return False
|
||||
|
||||
path = path if path.endswith('.pdparams') else path + '.pdparams'
|
||||
params = paddle.load(path)
|
||||
state_dict = model.state_dict()
|
||||
new_state_dict = {}
|
||||
for k1, k2 in zip(state_dict.keys(), params.keys()):
|
||||
if list(state_dict[k1].shape) == list(params[k2].shape):
|
||||
new_state_dict[k1] = params[k2]
|
||||
else:
|
||||
print(
|
||||
f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
|
||||
)
|
||||
model.set_state_dict(new_state_dict)
|
||||
print(f"load pretrain successful from {path}")
|
||||
return model
|
||||
|
||||
def save_model(model,
|
||||
optimizer,
|
||||
model_path,
|
||||
|
|
|
@ -0,0 +1,189 @@
|
|||
English | [简体中文](README_ch.md)
|
||||
|
||||
# PP-Structure
|
||||
|
||||
PP-Structure is an OCR toolkit that can be used for complex documents analysis. The main features are as follows:
|
||||
- Support the layout analysis of documents, divide the documents into 5 types of areas **text, title, table, image and list** (conjunction with Layout-Parser)
|
||||
- Support to extract the texts from the text, title, picture and list areas (used in conjunction with PP-OCR)
|
||||
- Support to extract excel files from the table areas
|
||||
- Support python whl package and command line usage, easy to use
|
||||
- Support custom training for layout analysis and table structure tasks
|
||||
|
||||
## 1. Visualization
|
||||
|
||||
<img src="../doc/table/ppstructure.GIF" width="100%"/>
|
||||
|
||||
|
||||
|
||||
## 2. Installation
|
||||
|
||||
### 2.1 Install requirements
|
||||
|
||||
- **(1) Install PaddlePaddle**
|
||||
|
||||
```bash
|
||||
pip3 install --upgrade pip
|
||||
|
||||
# GPU
|
||||
python3 -m pip install paddlepaddle-gpu==2.1.1 -i https://mirror.baidu.com/pypi/simple
|
||||
|
||||
# CPU
|
||||
python3 -m pip install paddlepaddle==2.1.1 -i https://mirror.baidu.com/pypi/simple
|
||||
|
||||
# For more,refer[Installation](https://www.paddlepaddle.org.cn/install/quick)。
|
||||
```
|
||||
|
||||
- **(2) Install Layout-Parser**
|
||||
|
||||
```bash
|
||||
pip3 install -U premailer paddleocr https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
|
||||
```
|
||||
|
||||
### 2.2 Install PaddleOCR(including PP-OCR and PP-Structure)
|
||||
|
||||
- **(1) PIP install PaddleOCR whl package(inference only)**
|
||||
|
||||
```bash
|
||||
pip install "paddleocr>=2.2"
|
||||
```
|
||||
|
||||
- **(2) Clone PaddleOCR(Inference+training)**
|
||||
|
||||
```bash
|
||||
git clone https://github.com/PaddlePaddle/PaddleOCR
|
||||
```
|
||||
|
||||
|
||||
## 3. Quick Start
|
||||
|
||||
### 3.1 Use by command line
|
||||
|
||||
```bash
|
||||
paddleocr --image_dir=../doc/table/1.png --type=structure
|
||||
```
|
||||
|
||||
### 3.2 Use by python API
|
||||
|
||||
```python
|
||||
import os
|
||||
import cv2
|
||||
from paddleocr import PPStructure,draw_structure_result,save_structure_res
|
||||
|
||||
table_engine = PPStructure(show_log=True)
|
||||
|
||||
save_folder = './output/table'
|
||||
img_path = '../doc/table/1.png'
|
||||
img = cv2.imread(img_path)
|
||||
result = table_engine(img)
|
||||
save_structure_res(result, save_folder,os.path.basename(img_path).split('.')[0])
|
||||
|
||||
for line in result:
|
||||
line.pop('img')
|
||||
print(line)
|
||||
|
||||
from PIL import Image
|
||||
|
||||
font_path = '../doc/fonts/simfang.ttf'
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
im_show = draw_structure_result(image, result,font_path=font_path)
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
### 3.3 Returned results format
|
||||
The returned results of PP-Structure is a list composed of a dict, an example is as follows
|
||||
|
||||
```shell
|
||||
[
|
||||
{ 'type': 'Text',
|
||||
'bbox': [34, 432, 345, 462],
|
||||
'res': ([[36.0, 437.0, 341.0, 437.0, 341.0, 446.0, 36.0, 447.0], [41.0, 454.0, 125.0, 453.0, 125.0, 459.0, 41.0, 460.0]],
|
||||
[('Tigure-6. The performance of CNN and IPT models using difforen', 0.90060663), ('Tent ', 0.465441)])
|
||||
}
|
||||
]
|
||||
```
|
||||
The description of each field in dict is as follows
|
||||
|
||||
| Parameter | Description |
|
||||
| --------------- | -------------|
|
||||
|type|Type of image area|
|
||||
|bbox|The coordinates of the image area in the original image, respectively [left upper x, left upper y, right bottom x, right bottom y]|
|
||||
|res|OCR or table recognition result of image area。<br> Table: HTML string of the table; <br> OCR: A tuple containing the detection coordinates and recognition results of each single line of text|
|
||||
|
||||
|
||||
### 3.4 Parameter description:
|
||||
|
||||
| Parameter | Description | Default value |
|
||||
| --------------- | ---------------------------------------- | ------------------------------------------- |
|
||||
| output | The path where excel and recognition results are saved | ./output/table |
|
||||
| table_max_len | The long side of the image is resized in table structure model | 488 |
|
||||
| table_model_dir | inference model path of table structure model | None |
|
||||
| table_char_type | dict path of table structure model | ../ppocr/utils/dict/table_structure_dict.tx |
|
||||
|
||||
Most of the parameters are consistent with the paddleocr whl package, see [doc of whl](../doc/doc_en/whl_en.md)
|
||||
|
||||
After running, each image will have a directory with the same name under the directory specified in the output field. Each table in the picture will be stored as an excel and figure area will be cropped and saved, the excel and image file name will be the coordinates of the table in the image.
|
||||
|
||||
## 4. PP-Structure Pipeline
|
||||
|
||||
the process is as follows
|
||||
![pipeline](../doc/table/pipeline_en.jpg)
|
||||
|
||||
In PP-Structure, the image will be analyzed by layoutparser first. In the layout analysis, the area in the image will be classified, including **text, title, image, list and table** 5 categories. For the first 4 types of areas, directly use the PP-OCR to complete the text detection and recognition. The table area will be converted to an excel file of the same table style via Table OCR.
|
||||
|
||||
### 4.1 LayoutParser
|
||||
|
||||
Layout analysis divides the document data into regions, including the use of Python scripts for layout analysis tools, extraction of special category detection boxes, performance indicators, and custom training layout analysis models. For details, please refer to [document](layout/README_en.md).
|
||||
|
||||
### 4.2 Table Recognition
|
||||
|
||||
Table Recognition converts table image into excel documents, which include the detection and recognition of table text and the prediction of table structure and cell coordinates. For detailed, please refer to [document](table/README.md)
|
||||
|
||||
## 5. Prediction by inference engine
|
||||
|
||||
Use the following commands to complete the inference.
|
||||
|
||||
```python
|
||||
cd PaddleOCR/ppstructure
|
||||
|
||||
# download model
|
||||
mkdir inference && cd inference
|
||||
# Download the detection model of the ultra-lightweight Chinese OCR model and uncompress it
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar
|
||||
# Download the recognition model of the ultra-lightweight Chinese OCR model and uncompress it
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
|
||||
# Download the table structure model of the ultra-lightweight Chinese OCR model and uncompress it
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
|
||||
cd ..
|
||||
|
||||
python3 predict_system.py --det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer --rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/1.png --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --output=../output/table --vis_font_path=../doc/fonts/simfang.ttf
|
||||
```
|
||||
After running, each image will have a directory with the same name under the directory specified in the output field. Each table in the picture will be stored as an excel and figure area will be cropped and saved, the excel and image file name will be the coordinates of the table in the image.
|
||||
|
||||
**Model List**
|
||||
|
||||
|
||||
|model name|description|config|model size|download|
|
||||
| --- | --- | --- | --- | --- |
|
||||
|en_ppocr_mobile_v2.0_table_structure|Table structure prediction for English table scenarios|[table_mv3.yml](../configs/table/table_mv3.yml)|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) |
|
||||
|
||||
**Model List**
|
||||
|
||||
LayoutParser model
|
||||
|
||||
|model name|description|download|
|
||||
| --- | --- | --- |
|
||||
| ppyolov2_r50vd_dcn_365e_publaynet | The layout analysis model trained on the PubLayNet data set can be divided into 5 types of areas **text, title, table, picture and list** | [PubLayNet](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_publaynet.tar) |
|
||||
| ppyolov2_r50vd_dcn_365e_tableBank_word | The layout analysis model trained on the TableBank Word dataset can only detect tables | [TableBank Word](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_word.tar) |
|
||||
| ppyolov2_r50vd_dcn_365e_tableBank_latex | The layout analysis model trained on the TableBank Latex dataset can only detect tables | [TableBank Latex](https://paddle-model-ecology.bj.bcebos.com/model/layout-parser/ppyolov2_r50vd_dcn_365e_tableBank_latex.tar) |
|
||||
|
||||
OCR and table recognition model
|
||||
|
||||
|model name|description|model size|download|
|
||||
| --- | --- | --- | --- |
|
||||
|ch_ppocr_mobile_slim_v2.0_det|Slim pruned lightweight model, supporting Chinese, English, multilingual text detection|2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
|
||||
|ch_ppocr_mobile_slim_v2.0_rec|Slim pruned and quantized lightweight model, supporting Chinese, English and number recognition|6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) |
|
||||
|en_ppocr_mobile_v2.0_table_det|Text detection of English table scenes trained on PubLayNet dataset|4.7M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) |
|
||||
|en_ppocr_mobile_v2.0_table_rec|Text recognition of English table scene trained on PubLayNet dataset|6.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) |
|
||||
|en_ppocr_mobile_v2.0_table_structure|Table structure prediction of English table scene trained on PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) |
|
||||
|
||||
If you need to use other models, you can download the model in [model_list](../doc/doc_en/models_list_en.md) or use your own trained model to configure it to the three fields of `det_model_dir`, `rec_model_dir`, `table_model_dir` .
|