fix conflict
This commit is contained in:
commit
f6daae41e5
|
@ -1031,7 +1031,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
||||||
|
|
||||||
for box in self.result_dic:
|
for box in self.result_dic:
|
||||||
trans_dic = {"label": box[1][0], "points": box[0], 'difficult': False}
|
trans_dic = {"label": box[1][0], "points": box[0], 'difficult': False}
|
||||||
if trans_dic["label"] is "" and mode == 'Auto':
|
if trans_dic["label"] == "" and mode == 'Auto':
|
||||||
continue
|
continue
|
||||||
shapes.append(trans_dic)
|
shapes.append(trans_dic)
|
||||||
|
|
||||||
|
@ -1450,7 +1450,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
||||||
item = QListWidgetItem(closeicon, filename)
|
item = QListWidgetItem(closeicon, filename)
|
||||||
self.fileListWidget.addItem(item)
|
self.fileListWidget.addItem(item)
|
||||||
|
|
||||||
print('dirPath in importDirImages is', dirpath)
|
print('DirPath in importDirImages is', dirpath)
|
||||||
self.iconlist.clear()
|
self.iconlist.clear()
|
||||||
self.additems5(dirpath)
|
self.additems5(dirpath)
|
||||||
self.changeFileFolder = True
|
self.changeFileFolder = True
|
||||||
|
@ -1459,7 +1459,6 @@ class MainWindow(QMainWindow, WindowMixin):
|
||||||
self.reRecogButton.setEnabled(True)
|
self.reRecogButton.setEnabled(True)
|
||||||
self.actions.AutoRec.setEnabled(True)
|
self.actions.AutoRec.setEnabled(True)
|
||||||
self.actions.reRec.setEnabled(True)
|
self.actions.reRec.setEnabled(True)
|
||||||
self.actions.saveLabel.setEnabled(True)
|
|
||||||
|
|
||||||
|
|
||||||
def openPrevImg(self, _value=False):
|
def openPrevImg(self, _value=False):
|
||||||
|
@ -1764,7 +1763,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
||||||
QMessageBox.information(self, "Information", msg)
|
QMessageBox.information(self, "Information", msg)
|
||||||
return
|
return
|
||||||
result = self.ocr.ocr(img_crop, cls=True, det=False)
|
result = self.ocr.ocr(img_crop, cls=True, det=False)
|
||||||
if result[0][0] is not '':
|
if result[0][0] != '':
|
||||||
result.insert(0, box)
|
result.insert(0, box)
|
||||||
print('result in reRec is ', result)
|
print('result in reRec is ', result)
|
||||||
self.result_dic.append(result)
|
self.result_dic.append(result)
|
||||||
|
@ -1795,7 +1794,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
||||||
QMessageBox.information(self, "Information", msg)
|
QMessageBox.information(self, "Information", msg)
|
||||||
return
|
return
|
||||||
result = self.ocr.ocr(img_crop, cls=True, det=False)
|
result = self.ocr.ocr(img_crop, cls=True, det=False)
|
||||||
if result[0][0] is not '':
|
if result[0][0] != '':
|
||||||
result.insert(0, box)
|
result.insert(0, box)
|
||||||
print('result in reRec is ', result)
|
print('result in reRec is ', result)
|
||||||
if result[1][0] == shape.label:
|
if result[1][0] == shape.label:
|
||||||
|
@ -1862,6 +1861,8 @@ class MainWindow(QMainWindow, WindowMixin):
|
||||||
for each in states:
|
for each in states:
|
||||||
file, state = each.split('\t')
|
file, state = each.split('\t')
|
||||||
self.fileStatedict[file] = 1
|
self.fileStatedict[file] = 1
|
||||||
|
self.actions.saveLabel.setEnabled(True)
|
||||||
|
self.actions.saveRec.setEnabled(True)
|
||||||
|
|
||||||
|
|
||||||
def saveFilestate(self):
|
def saveFilestate(self):
|
||||||
|
@ -1919,22 +1920,29 @@ class MainWindow(QMainWindow, WindowMixin):
|
||||||
|
|
||||||
rec_gt_dir = os.path.dirname(self.PPlabelpath) + '/rec_gt.txt'
|
rec_gt_dir = os.path.dirname(self.PPlabelpath) + '/rec_gt.txt'
|
||||||
crop_img_dir = os.path.dirname(self.PPlabelpath) + '/crop_img/'
|
crop_img_dir = os.path.dirname(self.PPlabelpath) + '/crop_img/'
|
||||||
|
ques_img = []
|
||||||
if not os.path.exists(crop_img_dir):
|
if not os.path.exists(crop_img_dir):
|
||||||
os.mkdir(crop_img_dir)
|
os.mkdir(crop_img_dir)
|
||||||
|
|
||||||
with open(rec_gt_dir, 'w', encoding='utf-8') as f:
|
with open(rec_gt_dir, 'w', encoding='utf-8') as f:
|
||||||
for key in self.fileStatedict:
|
for key in self.fileStatedict:
|
||||||
idx = self.getImglabelidx(key)
|
idx = self.getImglabelidx(key)
|
||||||
|
try:
|
||||||
|
img = cv2.imread(key)
|
||||||
for i, label in enumerate(self.PPlabel[idx]):
|
for i, label in enumerate(self.PPlabel[idx]):
|
||||||
if label['difficult']: continue
|
if label['difficult']: continue
|
||||||
img = cv2.imread(key)
|
|
||||||
img_crop = get_rotate_crop_image(img, np.array(label['points'], np.float32))
|
img_crop = get_rotate_crop_image(img, np.array(label['points'], np.float32))
|
||||||
img_name = os.path.splitext(os.path.basename(idx))[0] + '_crop_'+str(i)+'.jpg'
|
img_name = os.path.splitext(os.path.basename(idx))[0] + '_crop_'+str(i)+'.jpg'
|
||||||
cv2.imwrite(crop_img_dir+img_name, img_crop)
|
cv2.imwrite(crop_img_dir+img_name, img_crop)
|
||||||
f.write('crop_img/'+ img_name + '\t')
|
f.write('crop_img/'+ img_name + '\t')
|
||||||
f.write(label['transcription'] + '\n')
|
f.write(label['transcription'] + '\n')
|
||||||
|
except Exception as e:
|
||||||
QMessageBox.information(self, "Information", "Cropped images has been saved in "+str(crop_img_dir))
|
ques_img.append(key)
|
||||||
|
print("Can not read image ",e)
|
||||||
|
if ques_img:
|
||||||
|
QMessageBox.information(self, "Information", "The following images can not be saved, "
|
||||||
|
"please check the image path and labels.\n" + "".join(str(i)+'\n' for i in ques_img))
|
||||||
|
QMessageBox.information(self, "Information", "Cropped images have been saved in "+str(crop_img_dir))
|
||||||
|
|
||||||
def speedChoose(self):
|
def speedChoose(self):
|
||||||
if self.labelDialogOption.isChecked():
|
if self.labelDialogOption.isChecked():
|
||||||
|
@ -1991,7 +1999,7 @@ if __name__ == '__main__':
|
||||||
resource_file = './libs/resources.py'
|
resource_file = './libs/resources.py'
|
||||||
if not os.path.exists(resource_file):
|
if not os.path.exists(resource_file):
|
||||||
output = os.system('pyrcc5 -o libs/resources.py resources.qrc')
|
output = os.system('pyrcc5 -o libs/resources.py resources.qrc')
|
||||||
assert output is 0, "operate the cmd have some problems ,please check whether there is a in the lib " \
|
assert output == 0, "operate the cmd have some problems ,please check whether there is a in the lib " \
|
||||||
"directory resources.py "
|
"directory resources.py "
|
||||||
import libs.resources
|
import libs.resources
|
||||||
sys.exit(main())
|
sys.exit(main())
|
||||||
|
|
|
@ -5,7 +5,7 @@ PaddleOCR aims to create multilingual, awesome, leading, and practical OCR tools
|
||||||
|
|
||||||
## Notice
|
## Notice
|
||||||
PaddleOCR supports both dynamic graph and static graph programming paradigm
|
PaddleOCR supports both dynamic graph and static graph programming paradigm
|
||||||
- Dynamic graph: dygraph branch (default), **supported by paddle 2.0rc1+ ([installation](./doc/doc_en/installation_en.md))**
|
- Dynamic graph: dygraph branch (default), **supported by paddle 2.0.0 ([installation](./doc/doc_en/installation_en.md))**
|
||||||
- Static graph: develop branch
|
- Static graph: develop branch
|
||||||
|
|
||||||
**Recent updates**
|
**Recent updates**
|
||||||
|
|
|
@ -4,12 +4,12 @@
|
||||||
PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力使用者训练出更好的模型,并应用落地。
|
PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力使用者训练出更好的模型,并应用落地。
|
||||||
## 注意
|
## 注意
|
||||||
PaddleOCR同时支持动态图与静态图两种编程范式
|
PaddleOCR同时支持动态图与静态图两种编程范式
|
||||||
- 动态图版本:dygraph分支(默认),需将paddle版本升级至2.0rc1+([快速安装](./doc/doc_ch/installation.md))
|
- 动态图版本:dygraph分支(默认),需将paddle版本升级至2.0.0([快速安装](./doc/doc_ch/installation.md))
|
||||||
- 静态图版本:develop分支
|
- 静态图版本:develop分支
|
||||||
|
|
||||||
**近期更新**
|
**近期更新**
|
||||||
|
- 2021.2.1 [FAQ](./doc/doc_ch/FAQ.md)新增5个高频问题,总数162个,每周一都会更新,欢迎大家持续关注。
|
||||||
- 2021.1.26,28,29 PaddleOCR官方研发团队带来技术深入解读三日直播课,1月26日、28日、29日晚上19:30,[直播地址](https://live.bilibili.com/21689802)
|
- 2021.1.26,28,29 PaddleOCR官方研发团队带来技术深入解读三日直播课,1月26日、28日、29日晚上19:30,[直播地址](https://live.bilibili.com/21689802)
|
||||||
- 2021.1.25 [FAQ](./doc/doc_ch/FAQ.md)新增5个高频问题,总数157个,每周一都会更新,欢迎大家持续关注。
|
|
||||||
- 2021.1.21 更新多语言识别模型,目前支持语种超过27种,[多语言模型下载](./doc/doc_ch/models_list.md),包括中文简体、中文繁体、英文、法文、德文、韩文、日文、意大利文、西班牙文、葡萄牙文、俄罗斯文、阿拉伯文等,后续计划可以参考[多语言研发计划](https://github.com/PaddlePaddle/PaddleOCR/issues/1048)
|
- 2021.1.21 更新多语言识别模型,目前支持语种超过27种,[多语言模型下载](./doc/doc_ch/models_list.md),包括中文简体、中文繁体、英文、法文、德文、韩文、日文、意大利文、西班牙文、葡萄牙文、俄罗斯文、阿拉伯文等,后续计划可以参考[多语言研发计划](https://github.com/PaddlePaddle/PaddleOCR/issues/1048)
|
||||||
- 2020.12.15 更新数据合成工具[Style-Text](./StyleText/README_ch.md),可以批量合成大量与目标场景类似的图像,在多个场景验证,效果明显提升。
|
- 2020.12.15 更新数据合成工具[Style-Text](./StyleText/README_ch.md),可以批量合成大量与目标场景类似的图像,在多个场景验证,效果明显提升。
|
||||||
- 2020.11.25 更新半自动标注工具[PPOCRLabel](./PPOCRLabel/README_ch.md),辅助开发者高效完成标注任务,输出格式与PP-OCR训练任务完美衔接。
|
- 2020.11.25 更新半自动标注工具[PPOCRLabel](./PPOCRLabel/README_ch.md),辅助开发者高效完成标注任务,输出格式与PP-OCR训练任务完美衔接。
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
Global:
|
Global:
|
||||||
use_gpu: true
|
use_gpu: True
|
||||||
epoch_num: 72
|
epoch_num: 72
|
||||||
log_smooth_window: 20
|
log_smooth_window: 20
|
||||||
print_batch_step: 10
|
print_batch_step: 10
|
||||||
|
@ -59,7 +59,7 @@ Metric:
|
||||||
|
|
||||||
Train:
|
Train:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/training/
|
data_dir: ./train_data/data_lmdb_release/training/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
@ -78,7 +78,7 @@ Train:
|
||||||
|
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/validation/
|
data_dir: ./train_data/data_lmdb_release/validation/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
|
|
@ -58,7 +58,7 @@ Metric:
|
||||||
|
|
||||||
Train:
|
Train:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/training/
|
data_dir: ./train_data/data_lmdb_release/training/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
@ -77,7 +77,7 @@ Train:
|
||||||
|
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/validation/
|
data_dir: ./train_data/data_lmdb_release/validation/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
|
|
@ -0,0 +1,102 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: true
|
||||||
|
epoch_num: 72
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 10
|
||||||
|
save_model_dir: ./output/rec/rec_mv3_tps_bilstm_att/
|
||||||
|
save_epoch_step: 3
|
||||||
|
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||||
|
eval_batch_step: [0, 2000]
|
||||||
|
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||||
|
cal_metric_during_train: True
|
||||||
|
pretrained_model:
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||||
|
# for data or label process
|
||||||
|
character_dict_path:
|
||||||
|
character_type: en
|
||||||
|
max_text_length: 25
|
||||||
|
infer_mode: False
|
||||||
|
use_space_char: False
|
||||||
|
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
lr:
|
||||||
|
learning_rate: 0.0005
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
factor: 0.00001
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: rec
|
||||||
|
algorithm: RARE
|
||||||
|
Transform:
|
||||||
|
name: TPS
|
||||||
|
num_fiducial: 20
|
||||||
|
loc_lr: 0.1
|
||||||
|
model_name: small
|
||||||
|
Backbone:
|
||||||
|
name: MobileNetV3
|
||||||
|
scale: 0.5
|
||||||
|
model_name: large
|
||||||
|
Neck:
|
||||||
|
name: SequenceEncoder
|
||||||
|
encoder_type: rnn
|
||||||
|
hidden_size: 96
|
||||||
|
Head:
|
||||||
|
name: AttentionHead
|
||||||
|
hidden_size: 96
|
||||||
|
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: AttentionLoss
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: AttnLabelDecode
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: RecMetric
|
||||||
|
main_indicator: acc
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDateSet
|
||||||
|
data_dir: ../training/
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- AttnLabelEncode: # Class handling label
|
||||||
|
- RecResizeImg:
|
||||||
|
image_shape: [3, 32, 100]
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
batch_size_per_card: 256
|
||||||
|
drop_last: True
|
||||||
|
num_workers: 8
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDateSet
|
||||||
|
data_dir: ../validation/
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- AttnLabelEncode: # Class handling label
|
||||||
|
- RecResizeImg:
|
||||||
|
image_shape: [3, 32, 100]
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 256
|
||||||
|
num_workers: 1
|
|
@ -63,7 +63,7 @@ Metric:
|
||||||
|
|
||||||
Train:
|
Train:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/training/
|
data_dir: ./train_data/data_lmdb_release/training/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
@ -82,7 +82,7 @@ Train:
|
||||||
|
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/validation/
|
data_dir: ./train_data/data_lmdb_release/validation/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
|
|
@ -58,7 +58,7 @@ Metric:
|
||||||
|
|
||||||
Train:
|
Train:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/training/
|
data_dir: ./train_data/data_lmdb_release/training/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
@ -77,7 +77,7 @@ Train:
|
||||||
|
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/validation/
|
data_dir: ./train_data/data_lmdb_release/validation/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
|
|
@ -56,7 +56,7 @@ Metric:
|
||||||
|
|
||||||
Train:
|
Train:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/training/
|
data_dir: ./train_data/data_lmdb_release/training/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
@ -75,7 +75,7 @@ Train:
|
||||||
|
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/validation/
|
data_dir: ./train_data/data_lmdb_release/validation/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
|
|
@ -0,0 +1,101 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: true
|
||||||
|
epoch_num: 400
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 10
|
||||||
|
save_model_dir: ./output/rec/b3_rare_r34_none_gru/
|
||||||
|
save_epoch_step: 3
|
||||||
|
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||||
|
eval_batch_step: [0, 2000]
|
||||||
|
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||||
|
cal_metric_during_train: True
|
||||||
|
pretrained_model:
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||||
|
# for data or label process
|
||||||
|
character_dict_path:
|
||||||
|
character_type: en
|
||||||
|
max_text_length: 25
|
||||||
|
infer_mode: False
|
||||||
|
use_space_char: False
|
||||||
|
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
lr:
|
||||||
|
learning_rate: 0.0005
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
factor: 0.00000
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: rec
|
||||||
|
algorithm: RARE
|
||||||
|
Transform:
|
||||||
|
name: TPS
|
||||||
|
num_fiducial: 20
|
||||||
|
loc_lr: 0.1
|
||||||
|
model_name: large
|
||||||
|
Backbone:
|
||||||
|
name: ResNet
|
||||||
|
layers: 34
|
||||||
|
Neck:
|
||||||
|
name: SequenceEncoder
|
||||||
|
encoder_type: rnn
|
||||||
|
hidden_size: 256 #96
|
||||||
|
Head:
|
||||||
|
name: AttentionHead # AttentionHead
|
||||||
|
hidden_size: 256 #
|
||||||
|
l2_decay: 0.00001
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: AttentionLoss
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: AttnLabelDecode
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: RecMetric
|
||||||
|
main_indicator: acc
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDateSet
|
||||||
|
data_dir: ../training/
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- AttnLabelEncode: # Class handling label
|
||||||
|
- RecResizeImg:
|
||||||
|
image_shape: [3, 32, 100]
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
batch_size_per_card: 256
|
||||||
|
drop_last: True
|
||||||
|
num_workers: 8
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDateSet
|
||||||
|
data_dir: ../validation/
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- AttnLabelEncode: # Class handling label
|
||||||
|
- RecResizeImg:
|
||||||
|
image_shape: [3, 32, 100]
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 256
|
||||||
|
num_workers: 8
|
|
@ -62,7 +62,7 @@ Metric:
|
||||||
|
|
||||||
Train:
|
Train:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/training/
|
data_dir: ./train_data/data_lmdb_release/training/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
@ -81,7 +81,7 @@ Train:
|
||||||
|
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/validation/
|
data_dir: ./train_data/data_lmdb_release/validation/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
|
|
@ -0,0 +1,107 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: True
|
||||||
|
epoch_num: 72
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 5
|
||||||
|
save_model_dir: ./output/rec/srn_new
|
||||||
|
save_epoch_step: 3
|
||||||
|
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||||
|
eval_batch_step: [0, 5000]
|
||||||
|
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||||
|
cal_metric_during_train: True
|
||||||
|
pretrained_model:
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||||
|
# for data or label process
|
||||||
|
character_dict_path:
|
||||||
|
character_type: en
|
||||||
|
max_text_length: 25
|
||||||
|
num_heads: 8
|
||||||
|
infer_mode: False
|
||||||
|
use_space_char: False
|
||||||
|
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
clip_norm: 10.0
|
||||||
|
lr:
|
||||||
|
learning_rate: 0.0001
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: rec
|
||||||
|
algorithm: SRN
|
||||||
|
in_channels: 1
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: ResNetFPN
|
||||||
|
Head:
|
||||||
|
name: SRNHead
|
||||||
|
max_text_length: 25
|
||||||
|
num_heads: 8
|
||||||
|
num_encoder_TUs: 2
|
||||||
|
num_decoder_TUs: 4
|
||||||
|
hidden_dims: 512
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: SRNLoss
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: SRNLabelDecode
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: RecMetric
|
||||||
|
main_indicator: acc
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDataSet
|
||||||
|
data_dir: ./train_data/srn_train_data_duiqi
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- SRNLabelEncode: # Class handling label
|
||||||
|
- SRNRecResizeImg:
|
||||||
|
image_shape: [1, 64, 256]
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image',
|
||||||
|
'label',
|
||||||
|
'length',
|
||||||
|
'encoder_word_pos',
|
||||||
|
'gsrm_word_pos',
|
||||||
|
'gsrm_slf_attn_bias1',
|
||||||
|
'gsrm_slf_attn_bias2'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
batch_size_per_card: 64
|
||||||
|
drop_last: False
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDataSet
|
||||||
|
data_dir: ./train_data/data_lmdb_release/evaluation
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- SRNLabelEncode: # Class handling label
|
||||||
|
- SRNRecResizeImg:
|
||||||
|
image_shape: [1, 64, 256]
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image',
|
||||||
|
'label',
|
||||||
|
'length',
|
||||||
|
'encoder_word_pos',
|
||||||
|
'gsrm_word_pos',
|
||||||
|
'gsrm_slf_attn_bias1',
|
||||||
|
'gsrm_slf_attn_bias2']
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 32
|
||||||
|
num_workers: 4
|
|
@ -1,5 +1,5 @@
|
||||||
# Version: 2.0.0
|
# Version: 2.0.0
|
||||||
FROM registry.baidubce.com/paddlepaddle/paddle:2.0.0rc1
|
FROM registry.baidubce.com/paddlepaddle/paddle:2.0.0
|
||||||
|
|
||||||
# PaddleOCR base on Python3.7
|
# PaddleOCR base on Python3.7
|
||||||
RUN pip3.7 install --upgrade pip -i https://mirror.baidu.com/pypi/simple
|
RUN pip3.7 install --upgrade pip -i https://mirror.baidu.com/pypi/simple
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# Version: 2.0.0
|
# Version: 2.0.0
|
||||||
FROM egistry.baidubce.com/paddlepaddle/paddle:2.0.0rc1-gpu-cuda10.0-cudnn7
|
FROM registry.baidubce.com/paddlepaddle/paddle:2.0.0-gpu-cuda10.1-cudnn7
|
||||||
|
|
||||||
# PaddleOCR base on Python3.7
|
# PaddleOCR base on Python3.7
|
||||||
RUN pip3.7 install --upgrade pip -i https://mirror.baidu.com/pypi/simple
|
RUN pip3.7 install --upgrade pip -i https://mirror.baidu.com/pypi/simple
|
||||||
|
|
|
@ -9,43 +9,38 @@
|
||||||
|
|
||||||
## PaddleOCR常见问题汇总(持续更新)
|
## PaddleOCR常见问题汇总(持续更新)
|
||||||
|
|
||||||
* [近期更新(2021.1.25)](#近期更新)
|
* [近期更新(2021.2.1)](#近期更新)
|
||||||
* [【精选】OCR精选10个问题](#OCR精选10个问题)
|
* [【精选】OCR精选10个问题](#OCR精选10个问题)
|
||||||
* [【理论篇】OCR通用32个问题](#OCR通用问题)
|
* [【理论篇】OCR通用32个问题](#OCR通用问题)
|
||||||
* [基础知识7题](#基础知识)
|
* [基础知识7题](#基础知识)
|
||||||
* [数据集7题](#数据集2)
|
* [数据集7题](#数据集2)
|
||||||
* [模型训练调优18题](#模型训练调优2)
|
* [模型训练调优18题](#模型训练调优2)
|
||||||
* [【实战篇】PaddleOCR实战115个问题](#PaddleOCR实战问题)
|
* [【实战篇】PaddleOCR实战120个问题](#PaddleOCR实战问题)
|
||||||
* [使用咨询38题](#使用咨询)
|
* [使用咨询38题](#使用咨询)
|
||||||
* [数据集17题](#数据集3)
|
* [数据集18题](#数据集3)
|
||||||
* [模型训练调优28题](#模型训练调优3)
|
* [模型训练调优30题](#模型训练调优3)
|
||||||
* [预测部署32题](#预测部署3)
|
* [预测部署34题](#预测部署3)
|
||||||
|
|
||||||
|
|
||||||
<a name="近期更新"></a>
|
<a name="近期更新"></a>
|
||||||
## 近期更新(2021.1.25)
|
## 近期更新(2021.2.1)
|
||||||
|
|
||||||
#### Q3.1.37: 小语种模型只有识别模型,没有检测模型吗?
|
#### Q3.2.18: PaddleOCR动态图版本如何finetune?
|
||||||
|
**A**:finetune需要将配置文件里的 Global.load_static_weights设置为false,如果没有此字段可以手动添加,然后将模型地址放到Global.pretrained_model字段下即可。
|
||||||
|
|
||||||
**A**:小语种(包括纯英文数字)的检测模型和中文的检测模型是共用的,在训练中文检测模型时加入了多语言数据。https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_en/models_list_en.md#1-text-detection-model。
|
|
||||||
|
|
||||||
#### Q3.1.38: module 'paddle.distributed' has no attribute ‘get_rank’。
|
#### Q3.3.29: 微调v1.1预训练的模型,可以直接用文字垂直排列和上下颠倒的图片吗?还是必须要水平排列的?
|
||||||
|
**A**:1.1和2.0的模型一样,微调时,垂直排列的文字需要逆时针旋转 90° 后加入训练,上下颠倒的需要旋转为水平的。
|
||||||
|
|
||||||
**A**:Paddle版本问题,请安装2.0版本Paddle:pip install paddlepaddle==2.0.0rc1。
|
#### Q3.3.30: 模型训练过程中如何得到 best_accuracy 模型?
|
||||||
|
**A**:配置文件里的eval_batch_step字段用来控制多少次iter进行一次eval,在eval完成后会自动生成 best_accuracy 模型,所以如果希望很快就能拿到best_accuracy模型,可以将eval_batch_step改小一点(例如,10)。
|
||||||
|
|
||||||
#### Q3.4.30: PaddleOCR是否支持在华为鲲鹏920CPU上部署?
|
#### Q3.4.33: 如何多进程运行paddleocr?
|
||||||
|
**A**:实例化多个paddleocr服务,然后将服务注册到注册中心,之后通过注册中心统一调度即可,关于注册中心,可以搜索eureka了解一下具体使用,其他的注册中心也行。
|
||||||
|
|
||||||
**A**:目前Paddle的预测库是支持华为鲲鹏920CPU的,但是OCR还没在这些芯片上测试过,可以自己调试,有问题反馈给我们。
|
|
||||||
|
|
||||||
#### Q3.4.31: 采用Paddle-Lite进行端侧部署,出现问题,环境没问题。
|
#### Q3.4.34: 2.0训练出来的模型,能否在1.1版本上进行部署?
|
||||||
|
**A**:这个是不建议的,2.0训练出来的模型建议使用dygraph分支里提供的部署代码。
|
||||||
**A**:如果你的预测库是自己编译的,那么你的nb文件也要自己编译,用同一个lite版本。不能直接用下载的nb文件,因为版本不同。
|
|
||||||
|
|
||||||
#### Q3.4.32: PaddleOCR的模型支持onnx转换吗?
|
|
||||||
|
|
||||||
**A**:我们目前已经通过Paddle2ONNX来支持各模型套件的转换,PaddleOCR基于PaddlePaddle 2.0的版本(dygraph分支)已经支持导出为ONNX,欢迎关注Paddle2ONNX,了解更多项目的进展:
|
|
||||||
Paddle2ONNX项目:https://github.com/PaddlePaddle/Paddle2ONNX
|
|
||||||
Paddle2ONNX支持转换的[模型列表](https://github.com/PaddlePaddle/Paddle2ONNX/blob/develop/docs/zh/model_zoo.md#%E5%9B%BE%E5%83%8Focr)
|
|
||||||
|
|
||||||
<a name="OCR精选10个问题"></a>
|
<a name="OCR精选10个问题"></a>
|
||||||
## 【精选】OCR精选10个问题
|
## 【精选】OCR精选10个问题
|
||||||
|
@ -397,13 +392,13 @@ Paddle2ONNX支持转换的[模型列表](https://github.com/PaddlePaddle/Paddle2
|
||||||
**A**:动态图版本正在紧锣密鼓开发中,将于2020年12月16日发布,敬请关注。
|
**A**:动态图版本正在紧锣密鼓开发中,将于2020年12月16日发布,敬请关注。
|
||||||
|
|
||||||
#### Q3.1.22:ModuleNotFoundError: No module named 'paddle.nn',
|
#### Q3.1.22:ModuleNotFoundError: No module named 'paddle.nn',
|
||||||
**A**:paddle.nn是Paddle2.0版本特有的功能,请安装大于等于Paddle 2.0.0rc1的版本,安装方式为
|
**A**:paddle.nn是Paddle2.0版本特有的功能,请安装大于等于Paddle 2.0.0的版本,安装方式为
|
||||||
```
|
```
|
||||||
python3 -m pip install paddlepaddle-gpu==2.0.0rc1 -i https://mirror.baidu.com/pypi/simple
|
python3 -m pip install paddlepaddle-gpu==2.0.0 -i https://mirror.baidu.com/pypi/simple
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Q3.1.23: ImportError: /usr/lib/x86_64_linux-gnu/libstdc++.so.6:version `CXXABI_1.3.11` not found (required by /usr/lib/python3.6/site-package/paddle/fluid/core+avx.so)
|
#### Q3.1.23: ImportError: /usr/lib/x86_64_linux-gnu/libstdc++.so.6:version `CXXABI_1.3.11` not found (required by /usr/lib/python3.6/site-package/paddle/fluid/core+avx.so)
|
||||||
**A**:这个问题是glibc版本不足导致的,Paddle2.0rc1版本对gcc版本和glib版本有更高的要求,推荐gcc版本为8.2,glibc版本2.12以上。
|
**A**:这个问题是glibc版本不足导致的,Paddle2.0.0版本对gcc版本和glib版本有更高的要求,推荐gcc版本为8.2,glibc版本2.12以上。
|
||||||
如果您的环境不满足这个要求,或者使用的docker镜像为:
|
如果您的环境不满足这个要求,或者使用的docker镜像为:
|
||||||
`hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev`
|
`hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev`
|
||||||
`hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev`,安装Paddle2.0rc版本可能会出现上述错误,2.0版本推荐使用新的docker镜像 `paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82`。
|
`hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev`,安装Paddle2.0rc版本可能会出现上述错误,2.0版本推荐使用新的docker镜像 `paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82`。
|
||||||
|
@ -415,7 +410,7 @@ python3 -m pip install paddlepaddle-gpu==2.0.0rc1 -i https://mirror.baidu.com/py
|
||||||
|
|
||||||
- develop:基于Paddle静态图开发的分支,推荐使用paddle1.8 或者2.0版本,该分支具备完善的模型训练、预测、推理部署、量化裁剪等功能,领先于release/1.1分支。
|
- develop:基于Paddle静态图开发的分支,推荐使用paddle1.8 或者2.0版本,该分支具备完善的模型训练、预测、推理部署、量化裁剪等功能,领先于release/1.1分支。
|
||||||
- release/1.1:PaddleOCR 发布的第一个稳定版本,基于静态图开发,具备完善的训练、预测、推理部署、量化裁剪等功能。
|
- release/1.1:PaddleOCR 发布的第一个稳定版本,基于静态图开发,具备完善的训练、预测、推理部署、量化裁剪等功能。
|
||||||
- dygraph:基于Paddle动态图开发的分支,目前仍在开发中,未来将作为主要开发分支,运行要求使用Paddle2.0rc1版本,目前仍在开发中。
|
- dygraph:基于Paddle动态图开发的分支,目前仍在开发中,未来将作为主要开发分支,运行要求使用Paddle2.0.0版本。
|
||||||
- release/2.0-rc1-0:PaddleOCR发布的第二个稳定版本,基于动态图和paddle2.0版本开发,动态图开发的工程更易于调试,目前支,支持模型训练、预测,暂不支持移动端部署。
|
- release/2.0-rc1-0:PaddleOCR发布的第二个稳定版本,基于动态图和paddle2.0版本开发,动态图开发的工程更易于调试,目前支,支持模型训练、预测,暂不支持移动端部署。
|
||||||
|
|
||||||
如果您已经上手过PaddleOCR,并且希望在各种环境上部署PaddleOCR,目前建议使用静态图分支,develop或者release/1.1分支。如果您是初学者,想快速训练,调试PaddleOCR中的算法,建议尝鲜PaddleOCR dygraph分支。
|
如果您已经上手过PaddleOCR,并且希望在各种环境上部署PaddleOCR,目前建议使用静态图分支,develop或者release/1.1分支。如果您是初学者,想快速训练,调试PaddleOCR中的算法,建议尝鲜PaddleOCR dygraph分支。
|
||||||
|
@ -432,7 +427,7 @@ python3 -m pip install paddlepaddle-gpu==2.0.0rc1 -i https://mirror.baidu.com/py
|
||||||
|
|
||||||
#### Q3.1.27: 如何可视化acc,loss曲线图,模型网络结构图等?
|
#### Q3.1.27: 如何可视化acc,loss曲线图,模型网络结构图等?
|
||||||
|
|
||||||
**A**:在配置文件里有`use_visualdl`的参数,设置为True即可,更多的使用命令可以参考:[VisualDL使用指南](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/guides/03_VisualDL/visualdl.html)。
|
**A**:在配置文件里有`use_visualdl`的参数,设置为True即可,更多的使用命令可以参考:[VisualDL使用指南](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/03_VisualDL/visualdl.html)。
|
||||||
|
|
||||||
#### Q3.1.28: 在使用StyleText数据合成工具的时候,报错`ModuleNotFoundError: No module named 'utils.config'`,这是为什么呢?
|
#### Q3.1.28: 在使用StyleText数据合成工具的时候,报错`ModuleNotFoundError: No module named 'utils.config'`,这是为什么呢?
|
||||||
|
|
||||||
|
@ -451,7 +446,7 @@ https://github.com/PaddlePaddle/PaddleOCR/blob/de3e2e7cd3b8b65ee02d7a41e570fa5b5
|
||||||
|
|
||||||
#### Q3.1.31: 怎么输出网络结构以及每层的参数信息?
|
#### Q3.1.31: 怎么输出网络结构以及每层的参数信息?
|
||||||
|
|
||||||
**A**:可以使用 `paddle.summary`, 具体参考:https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/api/paddle/hapi/model_summary/summary_cn.html#summary。
|
**A**:可以使用 `paddle.summary`, 具体参考:https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/hapi/model_summary/summary_cn.html。
|
||||||
|
|
||||||
#### Q3.1.32 能否修改StyleText配置文件中的分辨率?
|
#### Q3.1.32 能否修改StyleText配置文件中的分辨率?
|
||||||
|
|
||||||
|
@ -485,7 +480,7 @@ StyleText的用途主要是:提取style_image中的字体、背景等style信
|
||||||
|
|
||||||
#### Q3.1.38: module 'paddle.distributed' has no attribute ‘get_rank’。
|
#### Q3.1.38: module 'paddle.distributed' has no attribute ‘get_rank’。
|
||||||
|
|
||||||
**A**:Paddle版本问题,请安装2.0版本Paddle:pip install paddlepaddle==2.0.0rc1。
|
**A**:Paddle版本问题,请安装2.0版本Paddle:pip install paddlepaddle==2.0.0。
|
||||||
|
|
||||||
<a name="数据集3"></a>
|
<a name="数据集3"></a>
|
||||||
### 数据集
|
### 数据集
|
||||||
|
@ -578,6 +573,9 @@ StyleText的用途主要是:提取style_image中的字体、背景等style信
|
||||||
|
|
||||||
**A**:PPOCRLabel可运行于Linux、Windows、MacOS等多种系统。操作步骤可以参考文档,https://github.com/PaddlePaddle/PaddleOCR/blob/develop/PPOCRLabel/README.md
|
**A**:PPOCRLabel可运行于Linux、Windows、MacOS等多种系统。操作步骤可以参考文档,https://github.com/PaddlePaddle/PaddleOCR/blob/develop/PPOCRLabel/README.md
|
||||||
|
|
||||||
|
#### Q3.2.18: PaddleOCR动态图版本如何finetune?
|
||||||
|
**A**:finetune需要将配置文件里的 Global.load_static_weights设置为false,如果没有此字段可以手动添加,然后将模型地址放到Global.pretrained_model字段下即可。
|
||||||
|
|
||||||
<a name="模型训练调优3"></a>
|
<a name="模型训练调优3"></a>
|
||||||
|
|
||||||
### 模型训练调优
|
### 模型训练调优
|
||||||
|
@ -723,6 +721,12 @@ ps -axu | grep train.py | awk '{print $2}' | xargs kill -9
|
||||||
|
|
||||||
**A**:可以参考[配置文件](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)在Train['dataset']['transforms']添加RecAug字段,使数据增强生效。可以通过添加对aug_prob设置,表示每种数据增强采用的概率。aug_prob默认是0.4.由于tia数据增强特殊性,默认不采用,可以通过添加use_tia设置,使tia数据增强生效。详细设置可以参考[ISSUE 1744](https://github.com/PaddlePaddle/PaddleOCR/issues/1744)。
|
**A**:可以参考[配置文件](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)在Train['dataset']['transforms']添加RecAug字段,使数据增强生效。可以通过添加对aug_prob设置,表示每种数据增强采用的概率。aug_prob默认是0.4.由于tia数据增强特殊性,默认不采用,可以通过添加use_tia设置,使tia数据增强生效。详细设置可以参考[ISSUE 1744](https://github.com/PaddlePaddle/PaddleOCR/issues/1744)。
|
||||||
|
|
||||||
|
#### Q3.3.29: 微调v1.1预训练的模型,可以直接用文字垂直排列和上下颠倒的图片吗?还是必须要水平排列的?
|
||||||
|
**A**:1.1和2.0的模型一样,微调时,垂直排列的文字需要逆时针旋转 90°后加入训练,上下颠倒的需要旋转为水平的。
|
||||||
|
|
||||||
|
#### Q3.3.30: 模型训练过程中如何得到 best_accuracy 模型?
|
||||||
|
**A**:配置文件里的eval_batch_step字段用来控制多少次iter进行一次eval,在eval完成后会自动生成 best_accuracy 模型,所以如果希望很快就能拿到best_accuracy模型,可以将eval_batch_step改小一点,如改为[10,10],这样表示第10次迭代后,以后没隔10个迭代就进行一次模型的评估。
|
||||||
|
|
||||||
<a name="预测部署3"></a>
|
<a name="预测部署3"></a>
|
||||||
|
|
||||||
### 预测部署
|
### 预测部署
|
||||||
|
@ -878,3 +882,10 @@ img = cv.imdecode(img_array, -1)
|
||||||
**A**:我们目前已经通过Paddle2ONNX来支持各模型套件的转换,PaddleOCR基于PaddlePaddle 2.0的版本(dygraph分支)已经支持导出为ONNX,欢迎关注Paddle2ONNX,了解更多项目的进展:
|
**A**:我们目前已经通过Paddle2ONNX来支持各模型套件的转换,PaddleOCR基于PaddlePaddle 2.0的版本(dygraph分支)已经支持导出为ONNX,欢迎关注Paddle2ONNX,了解更多项目的进展:
|
||||||
Paddle2ONNX项目:https://github.com/PaddlePaddle/Paddle2ONNX
|
Paddle2ONNX项目:https://github.com/PaddlePaddle/Paddle2ONNX
|
||||||
Paddle2ONNX支持转换的[模型列表](https://github.com/PaddlePaddle/Paddle2ONNX/blob/develop/docs/zh/model_zoo.md#%E5%9B%BE%E5%83%8Focr)
|
Paddle2ONNX支持转换的[模型列表](https://github.com/PaddlePaddle/Paddle2ONNX/blob/develop/docs/zh/model_zoo.md#%E5%9B%BE%E5%83%8Focr)
|
||||||
|
|
||||||
|
|
||||||
|
#### Q3.4.33: 如何多进程运行paddleocr?
|
||||||
|
**A**:实例化多个paddleocr服务,然后将服务注册到注册中心,之后通过注册中心统一调度即可,关于注册中心,可以搜索eureka了解一下具体使用,其他的注册中心也行。
|
||||||
|
|
||||||
|
#### Q3.4.34: 2.0训练出来的模型,能否在1.1版本上进行部署?
|
||||||
|
**A**:这个是不建议的,2.0训练出来的模型建议使用dygraph分支里提供的部署代码。
|
||||||
|
|
|
@ -40,8 +40,8 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|
||||||
- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))[7](ppocr推荐)
|
- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))[7](ppocr推荐)
|
||||||
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
|
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
|
||||||
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
|
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
|
||||||
- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] coming soon
|
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
|
||||||
- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))[5] coming soon
|
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
|
||||||
|
|
||||||
参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
||||||
|
|
||||||
|
@ -53,5 +53,9 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|
||||||
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|
||||||
|StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)|
|
|StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)|
|
||||||
|StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)|
|
|StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)|
|
||||||
|
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|
||||||
|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
||||||
|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
|
||||||
|
|
||||||
|
|
||||||
PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。
|
PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。
|
||||||
|
|
|
@ -63,7 +63,7 @@ PaddleOCR提供了训练脚本、评估脚本和预测脚本。
|
||||||
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
|
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
|
||||||
|
|
||||||
```
|
```
|
||||||
# GPU训练 支持单卡,多卡训练,通过 '--gpus' 指定卡号,如果使用的paddle版本小于2.0rc1,请使用'--select_gpus'参数选择要使用的GPU
|
# GPU训练 支持单卡,多卡训练,通过 '--gpus' 指定卡号。
|
||||||
# 启动训练,下面的命令已经写入train.sh文件中,只需修改文件里的配置文件路径即可
|
# 启动训练,下面的命令已经写入train.sh文件中,只需修改文件里的配置文件路径即可
|
||||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
|
||||||
```
|
```
|
||||||
|
|
|
@ -76,7 +76,7 @@ tar -xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_model
|
||||||
# 单机单卡训练 mv3_db 模型
|
# 单机单卡训练 mv3_db 模型
|
||||||
python3 tools/train.py -c configs/det/det_mv3_db.yml \
|
python3 tools/train.py -c configs/det/det_mv3_db.yml \
|
||||||
-o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/
|
-o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/
|
||||||
# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID;如果使用的paddle版本小于2.0rc1,请使用'--select_gpus'参数选择要使用的GPU
|
# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID
|
||||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \
|
||||||
-o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/
|
-o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/
|
||||||
```
|
```
|
||||||
|
|
|
@ -22,8 +22,9 @@ inference 模型(`paddle.jit.save`保存的模型)
|
||||||
- [三、文本识别模型推理](#文本识别模型推理)
|
- [三、文本识别模型推理](#文本识别模型推理)
|
||||||
- [1. 超轻量中文识别模型推理](#超轻量中文识别模型推理)
|
- [1. 超轻量中文识别模型推理](#超轻量中文识别模型推理)
|
||||||
- [2. 基于CTC损失的识别模型推理](#基于CTC损失的识别模型推理)
|
- [2. 基于CTC损失的识别模型推理](#基于CTC损失的识别模型推理)
|
||||||
- [3. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
|
- [3. 基于SRN损失的识别模型推理](#基于SRN损失的识别模型推理)
|
||||||
- [4. 多语言模型的推理](#多语言模型的推理)
|
- [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
|
||||||
|
- [5. 多语言模型的推理](#多语言模型的推理)
|
||||||
|
|
||||||
- [四、方向分类模型推理](#方向识别模型推理)
|
- [四、方向分类模型推理](#方向识别模型推理)
|
||||||
- [1. 方向分类模型推理](#方向分类模型推理)
|
- [1. 方向分类模型推理](#方向分类模型推理)
|
||||||
|
@ -295,8 +296,20 @@ Predicts of ./doc/imgs_words_en/word_336.png:('super', 0.9999073)
|
||||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||||
dict_character = list(self.character_str)
|
dict_character = list(self.character_str)
|
||||||
```
|
```
|
||||||
|
<a name="基于SRN损失的识别模型推理"></a>
|
||||||
|
### 3. 基于SRN损失的识别模型推理
|
||||||
|
基于SRN损失的识别模型,需要额外设置识别算法参数 --rec_algorithm="SRN"。
|
||||||
|
同时需要保证预测shape与训练时一致,如: --rec_image_shape="1, 64, 256"
|
||||||
|
|
||||||
### 3. 自定义文本识别字典的推理
|
```
|
||||||
|
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" \
|
||||||
|
--rec_model_dir="./inference/srn/" \
|
||||||
|
--rec_image_shape="1, 64, 256" \
|
||||||
|
--rec_char_type="en" \
|
||||||
|
--rec_algorithm="SRN"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. 自定义文本识别字典的推理
|
||||||
如果训练时修改了文本的字典,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径,并且设置 `rec_char_type=ch`
|
如果训练时修改了文本的字典,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径,并且设置 `rec_char_type=ch`
|
||||||
|
|
||||||
```
|
```
|
||||||
|
@ -304,7 +317,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
|
||||||
```
|
```
|
||||||
|
|
||||||
<a name="多语言模型的推理"></a>
|
<a name="多语言模型的推理"></a>
|
||||||
### 4. 多语言模型的推理
|
### 5. 多语言模型的推理
|
||||||
如果您需要预测的是其他语言模型,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径, 同时为了得到正确的可视化结果,
|
如果您需要预测的是其他语言模型,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径, 同时为了得到正确的可视化结果,
|
||||||
需要通过 `--vis_font_path` 指定可视化的字体路径,`doc/fonts/` 路径下有默认提供的小语种字体,例如韩文识别:
|
需要通过 `--vis_font_path` 指定可视化的字体路径,`doc/fonts/` 路径下有默认提供的小语种字体,例如韩文识别:
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
经测试PaddleOCR可在glibc 2.23上运行,您也可以测试其他glibc版本或安装glic 2.23
|
经测试PaddleOCR可在glibc 2.23上运行,您也可以测试其他glibc版本或安装glic 2.23
|
||||||
PaddleOCR 工作环境
|
PaddleOCR 工作环境
|
||||||
- PaddlePaddle 1.8+ ,推荐使用 PaddlePaddle 2.0rc1
|
- PaddlePaddle 2.0.0
|
||||||
- python3.7
|
- python3.7
|
||||||
- glibc 2.23
|
- glibc 2.23
|
||||||
- cuDNN 7.6+ (GPU)
|
- cuDNN 7.6+ (GPU)
|
||||||
|
@ -35,11 +35,11 @@ sudo docker container exec -it ppocr /bin/bash
|
||||||
pip3 install --upgrade pip
|
pip3 install --upgrade pip
|
||||||
|
|
||||||
如果您的机器安装的是CUDA9或CUDA10,请运行以下命令安装
|
如果您的机器安装的是CUDA9或CUDA10,请运行以下命令安装
|
||||||
python3 -m pip install paddlepaddle-gpu==2.0.0rc1 -i https://mirror.baidu.com/pypi/simple
|
python3 -m pip install paddlepaddle-gpu==2.0.0 -i https://mirror.baidu.com/pypi/simple
|
||||||
|
|
||||||
如果您的机器是CPU,请运行以下命令安装
|
如果您的机器是CPU,请运行以下命令安装
|
||||||
|
|
||||||
python3 -m pip install paddlepaddle==2.0.0rc1 -i https://mirror.baidu.com/pypi/simple
|
python3 -m pip install paddlepaddle==2.0.0 -i https://mirror.baidu.com/pypi/simple
|
||||||
|
|
||||||
更多的版本需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
|
更多的版本需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
|
||||||
```
|
```
|
||||||
|
|
|
@ -36,6 +36,7 @@ ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/dataset
|
||||||
* 数据下载
|
* 数据下载
|
||||||
|
|
||||||
若您本地没有数据集,可以在官网下载 [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) 数据,用于快速验证。也可以参考[DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here),下载 benchmark 所需的lmdb格式数据集。
|
若您本地没有数据集,可以在官网下载 [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) 数据,用于快速验证。也可以参考[DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here),下载 benchmark 所需的lmdb格式数据集。
|
||||||
|
如果希望复现SRN的论文指标,需要下载离线[增广数据](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA),提取码: y3ry。增广数据是由MJSynth和SynthText做旋转和扰动得到的。数据下载完成后请解压到 {your_path}/PaddleOCR/train_data/data_lmdb_release/training/ 路径下。
|
||||||
|
|
||||||
<a name="自定义数据集"></a>
|
<a name="自定义数据集"></a>
|
||||||
* 使用自己数据集
|
* 使用自己数据集
|
||||||
|
@ -200,6 +201,9 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
|
||||||
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
|
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
|
||||||
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
|
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
|
||||||
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
|
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
|
||||||
|
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
|
||||||
|
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
||||||
|
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
||||||
|
|
||||||
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
|
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
|
||||||
|
|
||||||
|
|
|
@ -42,8 +42,8 @@ PaddleOCR open-source text recognition algorithms list:
|
||||||
- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))[7]
|
- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))[7]
|
||||||
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
|
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
|
||||||
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
|
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
|
||||||
- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] coming soon
|
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
|
||||||
- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))[5] coming soon
|
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
|
||||||
|
|
||||||
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
|
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
|
||||||
|
|
||||||
|
@ -55,5 +55,8 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|
||||||
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|
||||||
|StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)|
|
|StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)|
|
||||||
|StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)|
|
|StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)|
|
||||||
|
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|
||||||
|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
||||||
|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
|
||||||
|
|
||||||
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
|
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
|
||||||
|
|
|
@ -66,7 +66,7 @@ Start training:
|
||||||
```
|
```
|
||||||
# Set PYTHONPATH path
|
# Set PYTHONPATH path
|
||||||
export PYTHONPATH=$PYTHONPATH:.
|
export PYTHONPATH=$PYTHONPATH:.
|
||||||
# GPU training Support single card and multi-card training, specify the card number through --gpus. If your paddle version is less than 2.0rc1, please use '--selected_gpus'
|
# GPU training Support single card and multi-card training, specify the card number through --gpus.
|
||||||
# Start training, the following command has been written into the train.sh file, just modify the configuration file path in the file
|
# Start training, the following command has been written into the train.sh file, just modify the configuration file path in the file
|
||||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
|
||||||
```
|
```
|
||||||
|
|
|
@ -76,7 +76,7 @@ You can also use `-o` to change the training parameters without modifying the ym
|
||||||
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
|
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
|
||||||
|
|
||||||
# multi-GPU training
|
# multi-GPU training
|
||||||
# Set the GPU ID used by the '--gpus' parameter; If your paddle version is less than 2.0rc1, please use '--selected_gpus'
|
# Set the GPU ID used by the '--gpus' parameter.
|
||||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,7 @@ Next, we first introduce how to convert a trained model into an inference model,
|
||||||
- [TEXT RECOGNITION MODEL INFERENCE](#RECOGNITION_MODEL_INFERENCE)
|
- [TEXT RECOGNITION MODEL INFERENCE](#RECOGNITION_MODEL_INFERENCE)
|
||||||
- [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_RECOGNITION)
|
- [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_RECOGNITION)
|
||||||
- [2. CTC-BASED TEXT RECOGNITION MODEL INFERENCE](#CTC-BASED_RECOGNITION)
|
- [2. CTC-BASED TEXT RECOGNITION MODEL INFERENCE](#CTC-BASED_RECOGNITION)
|
||||||
|
- [3. SRN-BASED TEXT RECOGNITION MODEL INFERENCE](#SRN-BASED_RECOGNITION)
|
||||||
- [3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
|
- [3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
|
||||||
- [4. MULTILINGUAL MODEL INFERENCE](MULTILINGUAL_MODEL_INFERENCE)
|
- [4. MULTILINGUAL MODEL INFERENCE](MULTILINGUAL_MODEL_INFERENCE)
|
||||||
|
|
||||||
|
@ -304,8 +305,23 @@ self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||||
dict_character = list(self.character_str)
|
dict_character = list(self.character_str)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
<a name="SRN-BASED_RECOGNITION"></a>
|
||||||
|
### 3. SRN-BASED TEXT RECOGNITION MODEL INFERENCE
|
||||||
|
|
||||||
|
The recognition model based on SRN requires additional setting of the recognition algorithm parameter
|
||||||
|
--rec_algorithm="SRN". At the same time, it is necessary to ensure that the predicted shape is consistent
|
||||||
|
with the training, such as: --rec_image_shape="1, 64, 256"
|
||||||
|
|
||||||
|
```
|
||||||
|
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" \
|
||||||
|
--rec_model_dir="./inference/srn/" \
|
||||||
|
--rec_image_shape="1, 64, 256" \
|
||||||
|
--rec_char_type="en" \
|
||||||
|
--rec_algorithm="SRN"
|
||||||
|
```
|
||||||
|
|
||||||
<a name="USING_CUSTOM_CHARACTERS"></a>
|
<a name="USING_CUSTOM_CHARACTERS"></a>
|
||||||
### 3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY
|
### 4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY
|
||||||
If the text dictionary is modified during training, when using the inference model to predict, you need to specify the dictionary path used by `--rec_char_dict_path`, and set `rec_char_type=ch`
|
If the text dictionary is modified during training, when using the inference model to predict, you need to specify the dictionary path used by `--rec_char_dict_path`, and set `rec_char_type=ch`
|
||||||
|
|
||||||
```
|
```
|
||||||
|
@ -313,7 +329,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
|
||||||
```
|
```
|
||||||
|
|
||||||
<a name="MULTILINGUAL_MODEL_INFERENCE"></a>
|
<a name="MULTILINGUAL_MODEL_INFERENCE"></a>
|
||||||
### 4. MULTILINGAUL MODEL INFERENCE
|
### 5. MULTILINGAUL MODEL INFERENCE
|
||||||
If you need to predict other language models, when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results,
|
If you need to predict other language models, when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results,
|
||||||
You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/fonts` path, such as Korean recognition:
|
You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/fonts` path, such as Korean recognition:
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
After testing, paddleocr can run on glibc 2.23. You can also test other glibc versions or install glic 2.23 for the best compatibility.
|
After testing, paddleocr can run on glibc 2.23. You can also test other glibc versions or install glic 2.23 for the best compatibility.
|
||||||
|
|
||||||
PaddleOCR working environment:
|
PaddleOCR working environment:
|
||||||
- PaddlePaddle 1.8+, Recommend PaddlePaddle 2.0rc1
|
- PaddlePaddle 2.0.0
|
||||||
- python3.7
|
- python3.7
|
||||||
- glibc 2.23
|
- glibc 2.23
|
||||||
|
|
||||||
|
@ -38,10 +38,10 @@ sudo docker container exec -it ppocr /bin/bash
|
||||||
pip3 install --upgrade pip
|
pip3 install --upgrade pip
|
||||||
|
|
||||||
# If you have cuda9 or cuda10 installed on your machine, please run the following command to install
|
# If you have cuda9 or cuda10 installed on your machine, please run the following command to install
|
||||||
python3 -m pip install paddlepaddle-gpu==2.0rc1 -i https://mirror.baidu.com/pypi/simple
|
python3 -m pip install paddlepaddle-gpu==2.0.0 -i https://mirror.baidu.com/pypi/simple
|
||||||
|
|
||||||
# If you only have cpu on your machine, please run the following command to install
|
# If you only have cpu on your machine, please run the following command to install
|
||||||
python3 -m pip install paddlepaddle==2.0rc1 -i https://mirror.baidu.com/pypi/simple
|
python3 -m pip install paddlepaddle==2.0.0 -i https://mirror.baidu.com/pypi/simple
|
||||||
```
|
```
|
||||||
For more software version requirements, please refer to the instructions in [Installation Document](https://www.paddlepaddle.org.cn/install/quick) for operation.
|
For more software version requirements, please refer to the instructions in [Installation Document](https://www.paddlepaddle.org.cn/install/quick) for operation.
|
||||||
|
|
||||||
|
|
|
@ -89,7 +89,7 @@ python3 generate_multi_language_configs.py -l it \
|
||||||
|model name|description|config|model size|download|
|
|model name|description|config|model size|download|
|
||||||
| --- | --- | --- | --- | --- |
|
| --- | --- | --- | --- | --- |
|
||||||
| french_mobile_v2.0_rec |Lightweight model for French recognition|[rec_french_lite_train.yml](../../configs/rec/multi_language/rec_french_lite_train.yml)|2.65M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_train.tar) |
|
| french_mobile_v2.0_rec |Lightweight model for French recognition|[rec_french_lite_train.yml](../../configs/rec/multi_language/rec_french_lite_train.yml)|2.65M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_train.tar) |
|
||||||
| german_mobile_v2.0_rec |Lightweight model for French recognition|[rec_german_lite_train.yml](../../configs/rec/multi_language/rec_german_lite_train.yml)|2.65M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_train.tar) |
|
| german_mobile_v2.0_rec |Lightweight model for German recognition|[rec_german_lite_train.yml](../../configs/rec/multi_language/rec_german_lite_train.yml)|2.65M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_train.tar) |
|
||||||
| korean_mobile_v2.0_rec |Lightweight model for Korean recognition|[rec_korean_lite_train.yml](../../configs/rec/multi_language/rec_korean_lite_train.yml)|3.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_train.tar) |
|
| korean_mobile_v2.0_rec |Lightweight model for Korean recognition|[rec_korean_lite_train.yml](../../configs/rec/multi_language/rec_korean_lite_train.yml)|3.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_train.tar) |
|
||||||
| japan_mobile_v2.0_rec |Lightweight model for Japanese recognition|[rec_japan_lite_train.yml](../../configs/rec/multi_language/rec_japan_lite_train.yml)|4.23M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_train.tar) |
|
| japan_mobile_v2.0_rec |Lightweight model for Japanese recognition|[rec_japan_lite_train.yml](../../configs/rec/multi_language/rec_japan_lite_train.yml)|4.23M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_train.tar) |
|
||||||
| it_mobile_v2.0_rec |Lightweight model for Italian recognition|rec_it_lite_train.yml|2.53M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_train.tar) |
|
| it_mobile_v2.0_rec |Lightweight model for Italian recognition|rec_it_lite_train.yml|2.53M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_train.tar) |
|
||||||
|
|
|
@ -195,6 +195,10 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
|
||||||
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
|
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
|
||||||
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
|
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
|
||||||
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
|
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
|
||||||
|
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
|
||||||
|
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
||||||
|
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
||||||
|
|
||||||
|
|
||||||
For training Chinese data, it is recommended to use
|
For training Chinese data, it is recommended to use
|
||||||
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
|
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
|
||||||
|
|
|
@ -33,7 +33,7 @@ import paddle.distributed as dist
|
||||||
|
|
||||||
from ppocr.data.imaug import transform, create_operators
|
from ppocr.data.imaug import transform, create_operators
|
||||||
from ppocr.data.simple_dataset import SimpleDataSet
|
from ppocr.data.simple_dataset import SimpleDataSet
|
||||||
from ppocr.data.lmdb_dataset import LMDBDateSet
|
from ppocr.data.lmdb_dataset import LMDBDataSet
|
||||||
|
|
||||||
__all__ = ['build_dataloader', 'transform', 'create_operators']
|
__all__ = ['build_dataloader', 'transform', 'create_operators']
|
||||||
|
|
||||||
|
@ -54,7 +54,7 @@ signal.signal(signal.SIGTERM, term_mp)
|
||||||
def build_dataloader(config, mode, device, logger, seed=None):
|
def build_dataloader(config, mode, device, logger, seed=None):
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
|
|
||||||
support_dict = ['SimpleDataSet', 'LMDBDateSet']
|
support_dict = ['SimpleDataSet', 'LMDBDataSet']
|
||||||
module_name = config[mode]['dataset']['name']
|
module_name = config[mode]['dataset']['name']
|
||||||
assert module_name in support_dict, Exception(
|
assert module_name in support_dict, Exception(
|
||||||
'DataSet only support {}'.format(support_dict))
|
'DataSet only support {}'.format(support_dict))
|
||||||
|
|
|
@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
|
||||||
from .make_shrink_map import MakeShrinkMap
|
from .make_shrink_map import MakeShrinkMap
|
||||||
from .random_crop_data import EastRandomCropData, PSERandomCrop
|
from .random_crop_data import EastRandomCropData, PSERandomCrop
|
||||||
|
|
||||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg
|
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
|
||||||
from .randaugment import RandAugment
|
from .randaugment import RandAugment
|
||||||
from .operators import *
|
from .operators import *
|
||||||
from .label_ops import *
|
from .label_ops import *
|
||||||
|
|
|
@ -102,6 +102,8 @@ class BaseRecLabelEncode(object):
|
||||||
support_character_type, character_type)
|
support_character_type, character_type)
|
||||||
|
|
||||||
self.max_text_len = max_text_length
|
self.max_text_len = max_text_length
|
||||||
|
self.beg_str = "sos"
|
||||||
|
self.end_str = "eos"
|
||||||
if character_type == "en":
|
if character_type == "en":
|
||||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||||
dict_character = list(self.character_str)
|
dict_character = list(self.character_str)
|
||||||
|
@ -197,16 +199,76 @@ class AttnLabelEncode(BaseRecLabelEncode):
|
||||||
super(AttnLabelEncode,
|
super(AttnLabelEncode,
|
||||||
self).__init__(max_text_length, character_dict_path,
|
self).__init__(max_text_length, character_dict_path,
|
||||||
character_type, use_space_char)
|
character_type, use_space_char)
|
||||||
self.beg_str = "sos"
|
|
||||||
self.end_str = "eos"
|
|
||||||
|
|
||||||
def add_special_char(self, dict_character):
|
def add_special_char(self, dict_character):
|
||||||
dict_character = [self.beg_str, self.end_str] + dict_character
|
self.beg_str = "sos"
|
||||||
|
self.end_str = "eos"
|
||||||
|
dict_character = [self.beg_str] + dict_character + [self.end_str]
|
||||||
return dict_character
|
return dict_character
|
||||||
|
|
||||||
def __call__(self, text):
|
def __call__(self, data):
|
||||||
|
text = data['label']
|
||||||
text = self.encode(text)
|
text = self.encode(text)
|
||||||
return text
|
if text is None:
|
||||||
|
return None
|
||||||
|
if len(text) >= self.max_text_len:
|
||||||
|
return None
|
||||||
|
data['length'] = np.array(len(text))
|
||||||
|
text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
|
||||||
|
- len(text) - 1)
|
||||||
|
data['label'] = np.array(text)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def get_ignored_tokens(self):
|
||||||
|
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||||
|
end_idx = self.get_beg_end_flag_idx("end")
|
||||||
|
return [beg_idx, end_idx]
|
||||||
|
|
||||||
|
def get_beg_end_flag_idx(self, beg_or_end):
|
||||||
|
if beg_or_end == "beg":
|
||||||
|
idx = np.array(self.dict[self.beg_str])
|
||||||
|
elif beg_or_end == "end":
|
||||||
|
idx = np.array(self.dict[self.end_str])
|
||||||
|
else:
|
||||||
|
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
|
||||||
|
% beg_or_end
|
||||||
|
return idx
|
||||||
|
|
||||||
|
|
||||||
|
class SRNLabelEncode(BaseRecLabelEncode):
|
||||||
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
max_text_length=25,
|
||||||
|
character_dict_path=None,
|
||||||
|
character_type='en',
|
||||||
|
use_space_char=False,
|
||||||
|
**kwargs):
|
||||||
|
super(SRNLabelEncode,
|
||||||
|
self).__init__(max_text_length, character_dict_path,
|
||||||
|
character_type, use_space_char)
|
||||||
|
|
||||||
|
def add_special_char(self, dict_character):
|
||||||
|
dict_character = dict_character + [self.beg_str, self.end_str]
|
||||||
|
return dict_character
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
text = data['label']
|
||||||
|
text = self.encode(text)
|
||||||
|
char_num = len(self.character_str)
|
||||||
|
if text is None:
|
||||||
|
return None
|
||||||
|
if len(text) > self.max_text_len:
|
||||||
|
return None
|
||||||
|
data['length'] = np.array(len(text))
|
||||||
|
text = text + [char_num] * (self.max_text_len - len(text))
|
||||||
|
data['label'] = np.array(text)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def get_ignored_tokens(self):
|
||||||
|
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||||
|
end_idx = self.get_beg_end_flag_idx("end")
|
||||||
|
return [beg_idx, end_idx]
|
||||||
|
|
||||||
def get_beg_end_flag_idx(self, beg_or_end):
|
def get_beg_end_flag_idx(self, beg_or_end):
|
||||||
if beg_or_end == "beg":
|
if beg_or_end == "beg":
|
||||||
|
|
|
@ -12,20 +12,6 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -77,6 +63,26 @@ class RecResizeImg(object):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class SRNRecResizeImg(object):
|
||||||
|
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
|
||||||
|
self.image_shape = image_shape
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.max_text_length = max_text_length
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
img = data['image']
|
||||||
|
norm_img = resize_norm_img_srn(img, self.image_shape)
|
||||||
|
data['image'] = norm_img
|
||||||
|
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
||||||
|
srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
|
||||||
|
|
||||||
|
data['encoder_word_pos'] = encoder_word_pos
|
||||||
|
data['gsrm_word_pos'] = gsrm_word_pos
|
||||||
|
data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1
|
||||||
|
data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def resize_norm_img(img, image_shape):
|
def resize_norm_img(img, image_shape):
|
||||||
imgC, imgH, imgW = image_shape
|
imgC, imgH, imgW = image_shape
|
||||||
h = img.shape[0]
|
h = img.shape[0]
|
||||||
|
@ -103,7 +109,7 @@ def resize_norm_img(img, image_shape):
|
||||||
def resize_norm_img_chinese(img, image_shape):
|
def resize_norm_img_chinese(img, image_shape):
|
||||||
imgC, imgH, imgW = image_shape
|
imgC, imgH, imgW = image_shape
|
||||||
# todo: change to 0 and modified image shape
|
# todo: change to 0 and modified image shape
|
||||||
max_wh_ratio = 0
|
max_wh_ratio = imgW * 1.0 / imgH
|
||||||
h, w = img.shape[0], img.shape[1]
|
h, w = img.shape[0], img.shape[1]
|
||||||
ratio = w * 1.0 / h
|
ratio = w * 1.0 / h
|
||||||
max_wh_ratio = max(max_wh_ratio, ratio)
|
max_wh_ratio = max(max_wh_ratio, ratio)
|
||||||
|
@ -126,6 +132,60 @@ def resize_norm_img_chinese(img, image_shape):
|
||||||
return padding_im
|
return padding_im
|
||||||
|
|
||||||
|
|
||||||
|
def resize_norm_img_srn(img, image_shape):
|
||||||
|
imgC, imgH, imgW = image_shape
|
||||||
|
|
||||||
|
img_black = np.zeros((imgH, imgW))
|
||||||
|
im_hei = img.shape[0]
|
||||||
|
im_wid = img.shape[1]
|
||||||
|
|
||||||
|
if im_wid <= im_hei * 1:
|
||||||
|
img_new = cv2.resize(img, (imgH * 1, imgH))
|
||||||
|
elif im_wid <= im_hei * 2:
|
||||||
|
img_new = cv2.resize(img, (imgH * 2, imgH))
|
||||||
|
elif im_wid <= im_hei * 3:
|
||||||
|
img_new = cv2.resize(img, (imgH * 3, imgH))
|
||||||
|
else:
|
||||||
|
img_new = cv2.resize(img, (imgW, imgH))
|
||||||
|
|
||||||
|
img_np = np.asarray(img_new)
|
||||||
|
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
|
||||||
|
img_black[:, 0:img_np.shape[1]] = img_np
|
||||||
|
img_black = img_black[:, :, np.newaxis]
|
||||||
|
|
||||||
|
row, col, c = img_black.shape
|
||||||
|
c = 1
|
||||||
|
|
||||||
|
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def srn_other_inputs(image_shape, num_heads, max_text_length):
|
||||||
|
|
||||||
|
imgC, imgH, imgW = image_shape
|
||||||
|
feature_dim = int((imgH / 8) * (imgW / 8))
|
||||||
|
|
||||||
|
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
|
||||||
|
(feature_dim, 1)).astype('int64')
|
||||||
|
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
|
||||||
|
(max_text_length, 1)).astype('int64')
|
||||||
|
|
||||||
|
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||||
|
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
|
||||||
|
[1, max_text_length, max_text_length])
|
||||||
|
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
|
||||||
|
[num_heads, 1, 1]) * [-1e9]
|
||||||
|
|
||||||
|
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
|
||||||
|
[1, max_text_length, max_text_length])
|
||||||
|
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
|
||||||
|
[num_heads, 1, 1]) * [-1e9]
|
||||||
|
|
||||||
|
return [
|
||||||
|
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||||
|
gsrm_slf_attn_bias2
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def flag():
|
def flag():
|
||||||
"""
|
"""
|
||||||
flag
|
flag
|
||||||
|
|
|
@ -20,9 +20,9 @@ import cv2
|
||||||
from .imaug import transform, create_operators
|
from .imaug import transform, create_operators
|
||||||
|
|
||||||
|
|
||||||
class LMDBDateSet(Dataset):
|
class LMDBDataSet(Dataset):
|
||||||
def __init__(self, config, mode, logger, seed=None):
|
def __init__(self, config, mode, logger, seed=None):
|
||||||
super(LMDBDateSet, self).__init__()
|
super(LMDBDataSet, self).__init__()
|
||||||
|
|
||||||
global_config = config['Global']
|
global_config = config['Global']
|
||||||
dataset_config = config[mode]['dataset']
|
dataset_config = config[mode]['dataset']
|
||||||
|
|
|
@ -23,11 +23,16 @@ def build_loss(config):
|
||||||
|
|
||||||
# rec loss
|
# rec loss
|
||||||
from .rec_ctc_loss import CTCLoss
|
from .rec_ctc_loss import CTCLoss
|
||||||
|
from .rec_att_loss import AttentionLoss
|
||||||
|
from .rec_srn_loss import SRNLoss
|
||||||
|
|
||||||
# cls loss
|
# cls loss
|
||||||
from .cls_loss import ClsLoss
|
from .cls_loss import ClsLoss
|
||||||
|
|
||||||
support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss']
|
support_dict = [
|
||||||
|
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||||
|
'SRNLoss'
|
||||||
|
]
|
||||||
|
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionLoss(nn.Layer):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(AttentionLoss, self).__init__()
|
||||||
|
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
|
||||||
|
|
||||||
|
def forward(self, predicts, batch):
|
||||||
|
targets = batch[1].astype("int64")
|
||||||
|
label_lengths = batch[2].astype('int64')
|
||||||
|
batch_size, num_steps, num_classes = predicts.shape[0], predicts.shape[
|
||||||
|
1], predicts.shape[2]
|
||||||
|
assert len(targets.shape) == len(list(predicts.shape)) - 1, \
|
||||||
|
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
|
||||||
|
|
||||||
|
inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
|
||||||
|
targets = paddle.reshape(targets, [-1])
|
||||||
|
|
||||||
|
return {'loss': paddle.sum(self.loss_func(inputs, targets))}
|
|
@ -0,0 +1,47 @@
|
||||||
|
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
|
||||||
|
class SRNLoss(nn.Layer):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(SRNLoss, self).__init__()
|
||||||
|
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum")
|
||||||
|
|
||||||
|
def forward(self, predicts, batch):
|
||||||
|
predict = predicts['predict']
|
||||||
|
word_predict = predicts['word_out']
|
||||||
|
gsrm_predict = predicts['gsrm_out']
|
||||||
|
label = batch[1]
|
||||||
|
|
||||||
|
casted_label = paddle.cast(x=label, dtype='int64')
|
||||||
|
casted_label = paddle.reshape(x=casted_label, shape=[-1, 1])
|
||||||
|
|
||||||
|
cost_word = self.loss_func(word_predict, label=casted_label)
|
||||||
|
cost_gsrm = self.loss_func(gsrm_predict, label=casted_label)
|
||||||
|
cost_vsfd = self.loss_func(predict, label=casted_label)
|
||||||
|
|
||||||
|
cost_word = paddle.reshape(x=paddle.sum(cost_word), shape=[1])
|
||||||
|
cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1])
|
||||||
|
cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1])
|
||||||
|
|
||||||
|
sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15
|
||||||
|
|
||||||
|
return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}
|
|
@ -33,8 +33,6 @@ class RecMetric(object):
|
||||||
if pred == target:
|
if pred == target:
|
||||||
correct_num += 1
|
correct_num += 1
|
||||||
all_num += 1
|
all_num += 1
|
||||||
# if all_num < 10 and kwargs.get('show_str', False):
|
|
||||||
# print('{} -> {}'.format(pred, target))
|
|
||||||
self.correct_num += correct_num
|
self.correct_num += correct_num
|
||||||
self.all_num += all_num
|
self.all_num += all_num
|
||||||
self.norm_edit_dis += norm_edit_dis
|
self.norm_edit_dis += norm_edit_dis
|
||||||
|
@ -50,7 +48,7 @@ class RecMetric(object):
|
||||||
'norm_edit_dis': 0,
|
'norm_edit_dis': 0,
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
acc = self.correct_num / self.all_num
|
acc = 1.0 * self.correct_num / self.all_num
|
||||||
norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
|
norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
|
||||||
self.reset()
|
self.reset()
|
||||||
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
|
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
|
||||||
|
|
|
@ -68,11 +68,14 @@ class BaseModel(nn.Layer):
|
||||||
config["Head"]['in_channels'] = in_channels
|
config["Head"]['in_channels'] = in_channels
|
||||||
self.head = build_head(config["Head"])
|
self.head = build_head(config["Head"])
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, data=None):
|
||||||
if self.use_transform:
|
if self.use_transform:
|
||||||
x = self.transform(x)
|
x = self.transform(x)
|
||||||
x = self.backbone(x)
|
x = self.backbone(x)
|
||||||
if self.use_neck:
|
if self.use_neck:
|
||||||
x = self.neck(x)
|
x = self.neck(x)
|
||||||
|
if data is None:
|
||||||
x = self.head(x)
|
x = self.head(x)
|
||||||
|
else:
|
||||||
|
x = self.head(x, data)
|
||||||
return x
|
return x
|
||||||
|
|
|
@ -24,7 +24,8 @@ def build_backbone(config, model_type):
|
||||||
elif model_type == 'rec' or model_type == 'cls':
|
elif model_type == 'rec' or model_type == 'cls':
|
||||||
from .rec_mobilenet_v3 import MobileNetV3
|
from .rec_mobilenet_v3 import MobileNetV3
|
||||||
from .rec_resnet_vd import ResNet
|
from .rec_resnet_vd import ResNet
|
||||||
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_FPN']
|
from .rec_resnet_fpn import ResNetFPN
|
||||||
|
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
|
@ -58,15 +58,15 @@ class MobileNetV3(nn.Layer):
|
||||||
[5, 72, 40, True, 'relu', 2],
|
[5, 72, 40, True, 'relu', 2],
|
||||||
[5, 120, 40, True, 'relu', 1],
|
[5, 120, 40, True, 'relu', 1],
|
||||||
[5, 120, 40, True, 'relu', 1],
|
[5, 120, 40, True, 'relu', 1],
|
||||||
[3, 240, 80, False, 'hard_swish', 2],
|
[3, 240, 80, False, 'hardswish', 2],
|
||||||
[3, 200, 80, False, 'hard_swish', 1],
|
[3, 200, 80, False, 'hardswish', 1],
|
||||||
[3, 184, 80, False, 'hard_swish', 1],
|
[3, 184, 80, False, 'hardswish', 1],
|
||||||
[3, 184, 80, False, 'hard_swish', 1],
|
[3, 184, 80, False, 'hardswish', 1],
|
||||||
[3, 480, 112, True, 'hard_swish', 1],
|
[3, 480, 112, True, 'hardswish', 1],
|
||||||
[3, 672, 112, True, 'hard_swish', 1],
|
[3, 672, 112, True, 'hardswish', 1],
|
||||||
[5, 672, 160, True, 'hard_swish', 2],
|
[5, 672, 160, True, 'hardswish', 2],
|
||||||
[5, 960, 160, True, 'hard_swish', 1],
|
[5, 960, 160, True, 'hardswish', 1],
|
||||||
[5, 960, 160, True, 'hard_swish', 1],
|
[5, 960, 160, True, 'hardswish', 1],
|
||||||
]
|
]
|
||||||
cls_ch_squeeze = 960
|
cls_ch_squeeze = 960
|
||||||
elif model_name == "small":
|
elif model_name == "small":
|
||||||
|
@ -75,14 +75,14 @@ class MobileNetV3(nn.Layer):
|
||||||
[3, 16, 16, True, 'relu', 2],
|
[3, 16, 16, True, 'relu', 2],
|
||||||
[3, 72, 24, False, 'relu', 2],
|
[3, 72, 24, False, 'relu', 2],
|
||||||
[3, 88, 24, False, 'relu', 1],
|
[3, 88, 24, False, 'relu', 1],
|
||||||
[5, 96, 40, True, 'hard_swish', 2],
|
[5, 96, 40, True, 'hardswish', 2],
|
||||||
[5, 240, 40, True, 'hard_swish', 1],
|
[5, 240, 40, True, 'hardswish', 1],
|
||||||
[5, 240, 40, True, 'hard_swish', 1],
|
[5, 240, 40, True, 'hardswish', 1],
|
||||||
[5, 120, 48, True, 'hard_swish', 1],
|
[5, 120, 48, True, 'hardswish', 1],
|
||||||
[5, 144, 48, True, 'hard_swish', 1],
|
[5, 144, 48, True, 'hardswish', 1],
|
||||||
[5, 288, 96, True, 'hard_swish', 2],
|
[5, 288, 96, True, 'hardswish', 2],
|
||||||
[5, 576, 96, True, 'hard_swish', 1],
|
[5, 576, 96, True, 'hardswish', 1],
|
||||||
[5, 576, 96, True, 'hard_swish', 1],
|
[5, 576, 96, True, 'hardswish', 1],
|
||||||
]
|
]
|
||||||
cls_ch_squeeze = 576
|
cls_ch_squeeze = 576
|
||||||
else:
|
else:
|
||||||
|
@ -102,7 +102,7 @@ class MobileNetV3(nn.Layer):
|
||||||
padding=1,
|
padding=1,
|
||||||
groups=1,
|
groups=1,
|
||||||
if_act=True,
|
if_act=True,
|
||||||
act='hard_swish',
|
act='hardswish',
|
||||||
name='conv1')
|
name='conv1')
|
||||||
|
|
||||||
self.stages = []
|
self.stages = []
|
||||||
|
@ -138,7 +138,7 @@ class MobileNetV3(nn.Layer):
|
||||||
padding=0,
|
padding=0,
|
||||||
groups=1,
|
groups=1,
|
||||||
if_act=True,
|
if_act=True,
|
||||||
act='hard_swish',
|
act='hardswish',
|
||||||
name='conv_last'))
|
name='conv_last'))
|
||||||
self.stages.append(nn.Sequential(*block_list))
|
self.stages.append(nn.Sequential(*block_list))
|
||||||
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
|
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
|
||||||
|
@ -192,10 +192,11 @@ class ConvBNLayer(nn.Layer):
|
||||||
if self.if_act:
|
if self.if_act:
|
||||||
if self.act == "relu":
|
if self.act == "relu":
|
||||||
x = F.relu(x)
|
x = F.relu(x)
|
||||||
elif self.act == "hard_swish":
|
elif self.act == "hardswish":
|
||||||
x = F.activation.hard_swish(x)
|
x = F.hardswish(x)
|
||||||
else:
|
else:
|
||||||
print("The activation function is selected incorrectly.")
|
print("The activation function({}) is selected incorrectly.".
|
||||||
|
format(self.act))
|
||||||
exit()
|
exit()
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -282,5 +283,5 @@ class SEModule(nn.Layer):
|
||||||
outputs = self.conv1(outputs)
|
outputs = self.conv1(outputs)
|
||||||
outputs = F.relu(outputs)
|
outputs = F.relu(outputs)
|
||||||
outputs = self.conv2(outputs)
|
outputs = self.conv2(outputs)
|
||||||
outputs = F.activation.hard_sigmoid(outputs)
|
outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
|
||||||
return inputs * outputs
|
return inputs * outputs
|
||||||
|
|
|
@ -51,15 +51,15 @@ class MobileNetV3(nn.Layer):
|
||||||
[5, 72, 40, True, 'relu', (large_stride[2], 1)],
|
[5, 72, 40, True, 'relu', (large_stride[2], 1)],
|
||||||
[5, 120, 40, True, 'relu', 1],
|
[5, 120, 40, True, 'relu', 1],
|
||||||
[5, 120, 40, True, 'relu', 1],
|
[5, 120, 40, True, 'relu', 1],
|
||||||
[3, 240, 80, False, 'hard_swish', 1],
|
[3, 240, 80, False, 'hardswish', 1],
|
||||||
[3, 200, 80, False, 'hard_swish', 1],
|
[3, 200, 80, False, 'hardswish', 1],
|
||||||
[3, 184, 80, False, 'hard_swish', 1],
|
[3, 184, 80, False, 'hardswish', 1],
|
||||||
[3, 184, 80, False, 'hard_swish', 1],
|
[3, 184, 80, False, 'hardswish', 1],
|
||||||
[3, 480, 112, True, 'hard_swish', 1],
|
[3, 480, 112, True, 'hardswish', 1],
|
||||||
[3, 672, 112, True, 'hard_swish', 1],
|
[3, 672, 112, True, 'hardswish', 1],
|
||||||
[5, 672, 160, True, 'hard_swish', (large_stride[3], 1)],
|
[5, 672, 160, True, 'hardswish', (large_stride[3], 1)],
|
||||||
[5, 960, 160, True, 'hard_swish', 1],
|
[5, 960, 160, True, 'hardswish', 1],
|
||||||
[5, 960, 160, True, 'hard_swish', 1],
|
[5, 960, 160, True, 'hardswish', 1],
|
||||||
]
|
]
|
||||||
cls_ch_squeeze = 960
|
cls_ch_squeeze = 960
|
||||||
elif model_name == "small":
|
elif model_name == "small":
|
||||||
|
@ -68,14 +68,14 @@ class MobileNetV3(nn.Layer):
|
||||||
[3, 16, 16, True, 'relu', (small_stride[0], 1)],
|
[3, 16, 16, True, 'relu', (small_stride[0], 1)],
|
||||||
[3, 72, 24, False, 'relu', (small_stride[1], 1)],
|
[3, 72, 24, False, 'relu', (small_stride[1], 1)],
|
||||||
[3, 88, 24, False, 'relu', 1],
|
[3, 88, 24, False, 'relu', 1],
|
||||||
[5, 96, 40, True, 'hard_swish', (small_stride[2], 1)],
|
[5, 96, 40, True, 'hardswish', (small_stride[2], 1)],
|
||||||
[5, 240, 40, True, 'hard_swish', 1],
|
[5, 240, 40, True, 'hardswish', 1],
|
||||||
[5, 240, 40, True, 'hard_swish', 1],
|
[5, 240, 40, True, 'hardswish', 1],
|
||||||
[5, 120, 48, True, 'hard_swish', 1],
|
[5, 120, 48, True, 'hardswish', 1],
|
||||||
[5, 144, 48, True, 'hard_swish', 1],
|
[5, 144, 48, True, 'hardswish', 1],
|
||||||
[5, 288, 96, True, 'hard_swish', (small_stride[3], 1)],
|
[5, 288, 96, True, 'hardswish', (small_stride[3], 1)],
|
||||||
[5, 576, 96, True, 'hard_swish', 1],
|
[5, 576, 96, True, 'hardswish', 1],
|
||||||
[5, 576, 96, True, 'hard_swish', 1],
|
[5, 576, 96, True, 'hardswish', 1],
|
||||||
]
|
]
|
||||||
cls_ch_squeeze = 576
|
cls_ch_squeeze = 576
|
||||||
else:
|
else:
|
||||||
|
@ -96,7 +96,7 @@ class MobileNetV3(nn.Layer):
|
||||||
padding=1,
|
padding=1,
|
||||||
groups=1,
|
groups=1,
|
||||||
if_act=True,
|
if_act=True,
|
||||||
act='hard_swish',
|
act='hardswish',
|
||||||
name='conv1')
|
name='conv1')
|
||||||
i = 0
|
i = 0
|
||||||
block_list = []
|
block_list = []
|
||||||
|
@ -124,7 +124,7 @@ class MobileNetV3(nn.Layer):
|
||||||
padding=0,
|
padding=0,
|
||||||
groups=1,
|
groups=1,
|
||||||
if_act=True,
|
if_act=True,
|
||||||
act='hard_swish',
|
act='hardswish',
|
||||||
name='conv_last')
|
name='conv_last')
|
||||||
|
|
||||||
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
|
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
|
||||||
|
|
|
@ -0,0 +1,307 @@
|
||||||
|
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
#you may not use this file except in compliance with the License.
|
||||||
|
#You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
#Unless required by applicable law or agreed to in writing, software
|
||||||
|
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
#See the License for the specific language governing permissions and
|
||||||
|
#limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from paddle import nn, ParamAttr
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
import paddle
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
__all__ = ["ResNetFPN"]
|
||||||
|
|
||||||
|
|
||||||
|
class ResNetFPN(nn.Layer):
|
||||||
|
def __init__(self, in_channels=1, layers=50, **kwargs):
|
||||||
|
super(ResNetFPN, self).__init__()
|
||||||
|
supported_layers = {
|
||||||
|
18: {
|
||||||
|
'depth': [2, 2, 2, 2],
|
||||||
|
'block_class': BasicBlock
|
||||||
|
},
|
||||||
|
34: {
|
||||||
|
'depth': [3, 4, 6, 3],
|
||||||
|
'block_class': BasicBlock
|
||||||
|
},
|
||||||
|
50: {
|
||||||
|
'depth': [3, 4, 6, 3],
|
||||||
|
'block_class': BottleneckBlock
|
||||||
|
},
|
||||||
|
101: {
|
||||||
|
'depth': [3, 4, 23, 3],
|
||||||
|
'block_class': BottleneckBlock
|
||||||
|
},
|
||||||
|
152: {
|
||||||
|
'depth': [3, 8, 36, 3],
|
||||||
|
'block_class': BottleneckBlock
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)]
|
||||||
|
num_filters = [64, 128, 256, 512]
|
||||||
|
self.depth = supported_layers[layers]['depth']
|
||||||
|
self.F = []
|
||||||
|
self.conv = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=64,
|
||||||
|
kernel_size=7,
|
||||||
|
stride=2,
|
||||||
|
act="relu",
|
||||||
|
name="conv1")
|
||||||
|
self.block_list = []
|
||||||
|
in_ch = 64
|
||||||
|
if layers >= 50:
|
||||||
|
for block in range(len(self.depth)):
|
||||||
|
for i in range(self.depth[block]):
|
||||||
|
if layers in [101, 152] and block == 2:
|
||||||
|
if i == 0:
|
||||||
|
conv_name = "res" + str(block + 2) + "a"
|
||||||
|
else:
|
||||||
|
conv_name = "res" + str(block + 2) + "b" + str(i)
|
||||||
|
else:
|
||||||
|
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||||
|
block_list = self.add_sublayer(
|
||||||
|
"bottleneckBlock_{}_{}".format(block, i),
|
||||||
|
BottleneckBlock(
|
||||||
|
in_channels=in_ch,
|
||||||
|
out_channels=num_filters[block],
|
||||||
|
stride=stride_list[block] if i == 0 else 1,
|
||||||
|
name=conv_name))
|
||||||
|
in_ch = num_filters[block] * 4
|
||||||
|
self.block_list.append(block_list)
|
||||||
|
self.F.append(block_list)
|
||||||
|
else:
|
||||||
|
for block in range(len(self.depth)):
|
||||||
|
for i in range(self.depth[block]):
|
||||||
|
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||||
|
if i == 0 and block != 0:
|
||||||
|
stride = (2, 1)
|
||||||
|
else:
|
||||||
|
stride = (1, 1)
|
||||||
|
basic_block = self.add_sublayer(
|
||||||
|
conv_name,
|
||||||
|
BasicBlock(
|
||||||
|
in_channels=in_ch,
|
||||||
|
out_channels=num_filters[block],
|
||||||
|
stride=stride_list[block] if i == 0 else 1,
|
||||||
|
is_first=block == i == 0,
|
||||||
|
name=conv_name))
|
||||||
|
in_ch = basic_block.out_channels
|
||||||
|
self.block_list.append(basic_block)
|
||||||
|
out_ch_list = [in_ch // 4, in_ch // 2, in_ch]
|
||||||
|
self.base_block = []
|
||||||
|
self.conv_trans = []
|
||||||
|
self.bn_block = []
|
||||||
|
for i in [-2, -3]:
|
||||||
|
in_channels = out_ch_list[i + 1] + out_ch_list[i]
|
||||||
|
|
||||||
|
self.base_block.append(
|
||||||
|
self.add_sublayer(
|
||||||
|
"F_{}_base_block_0".format(i),
|
||||||
|
nn.Conv2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_ch_list[i],
|
||||||
|
kernel_size=1,
|
||||||
|
weight_attr=ParamAttr(trainable=True),
|
||||||
|
bias_attr=ParamAttr(trainable=True))))
|
||||||
|
self.base_block.append(
|
||||||
|
self.add_sublayer(
|
||||||
|
"F_{}_base_block_1".format(i),
|
||||||
|
nn.Conv2D(
|
||||||
|
in_channels=out_ch_list[i],
|
||||||
|
out_channels=out_ch_list[i],
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
weight_attr=ParamAttr(trainable=True),
|
||||||
|
bias_attr=ParamAttr(trainable=True))))
|
||||||
|
self.base_block.append(
|
||||||
|
self.add_sublayer(
|
||||||
|
"F_{}_base_block_2".format(i),
|
||||||
|
nn.BatchNorm(
|
||||||
|
num_channels=out_ch_list[i],
|
||||||
|
act="relu",
|
||||||
|
param_attr=ParamAttr(trainable=True),
|
||||||
|
bias_attr=ParamAttr(trainable=True))))
|
||||||
|
self.base_block.append(
|
||||||
|
self.add_sublayer(
|
||||||
|
"F_{}_base_block_3".format(i),
|
||||||
|
nn.Conv2D(
|
||||||
|
in_channels=out_ch_list[i],
|
||||||
|
out_channels=512,
|
||||||
|
kernel_size=1,
|
||||||
|
bias_attr=ParamAttr(trainable=True),
|
||||||
|
weight_attr=ParamAttr(trainable=True))))
|
||||||
|
self.out_channels = 512
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
fpn_list = []
|
||||||
|
F = []
|
||||||
|
for i in range(len(self.depth)):
|
||||||
|
fpn_list.append(np.sum(self.depth[:i + 1]))
|
||||||
|
|
||||||
|
for i, block in enumerate(self.block_list):
|
||||||
|
x = block(x)
|
||||||
|
for number in fpn_list:
|
||||||
|
if i + 1 == number:
|
||||||
|
F.append(x)
|
||||||
|
base = F[-1]
|
||||||
|
|
||||||
|
j = 0
|
||||||
|
for i, block in enumerate(self.base_block):
|
||||||
|
if i % 3 == 0 and i < 6:
|
||||||
|
j = j + 1
|
||||||
|
b, c, w, h = F[-j - 1].shape
|
||||||
|
if [w, h] == list(base.shape[2:]):
|
||||||
|
base = base
|
||||||
|
else:
|
||||||
|
base = self.conv_trans[j - 1](base)
|
||||||
|
base = self.bn_block[j - 1](base)
|
||||||
|
base = paddle.concat([base, F[-j - 1]], axis=1)
|
||||||
|
base = block(base)
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBNLayer(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
groups=1,
|
||||||
|
act=None,
|
||||||
|
name=None):
|
||||||
|
super(ConvBNLayer, self).__init__()
|
||||||
|
self.conv = nn.Conv2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=2 if stride == (1, 1) else kernel_size,
|
||||||
|
dilation=2 if stride == (1, 1) else 1,
|
||||||
|
stride=stride,
|
||||||
|
padding=(kernel_size - 1) // 2,
|
||||||
|
groups=groups,
|
||||||
|
weight_attr=ParamAttr(name=name + '.conv2d.output.1.w_0'),
|
||||||
|
bias_attr=False, )
|
||||||
|
|
||||||
|
if name == "conv1":
|
||||||
|
bn_name = "bn_" + name
|
||||||
|
else:
|
||||||
|
bn_name = "bn" + name[3:]
|
||||||
|
self.bn = nn.BatchNorm(
|
||||||
|
num_channels=out_channels,
|
||||||
|
act=act,
|
||||||
|
param_attr=ParamAttr(name=name + '.output.1.w_0'),
|
||||||
|
bias_attr=ParamAttr(name=name + '.output.1.b_0'),
|
||||||
|
moving_mean_name=bn_name + "_mean",
|
||||||
|
moving_variance_name=bn_name + "_variance")
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ShortCut(nn.Layer):
|
||||||
|
def __init__(self, in_channels, out_channels, stride, name, is_first=False):
|
||||||
|
super(ShortCut, self).__init__()
|
||||||
|
self.use_conv = True
|
||||||
|
|
||||||
|
if in_channels != out_channels or stride != 1 or is_first == True:
|
||||||
|
if stride == (1, 1):
|
||||||
|
self.conv = ConvBNLayer(
|
||||||
|
in_channels, out_channels, 1, 1, name=name)
|
||||||
|
else: # stride==(2,2)
|
||||||
|
self.conv = ConvBNLayer(
|
||||||
|
in_channels, out_channels, 1, stride, name=name)
|
||||||
|
else:
|
||||||
|
self.use_conv = False
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.use_conv:
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BottleneckBlock(nn.Layer):
|
||||||
|
def __init__(self, in_channels, out_channels, stride, name):
|
||||||
|
super(BottleneckBlock, self).__init__()
|
||||||
|
self.conv0 = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
act='relu',
|
||||||
|
name=name + "_branch2a")
|
||||||
|
self.conv1 = ConvBNLayer(
|
||||||
|
in_channels=out_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=stride,
|
||||||
|
act='relu',
|
||||||
|
name=name + "_branch2b")
|
||||||
|
|
||||||
|
self.conv2 = ConvBNLayer(
|
||||||
|
in_channels=out_channels,
|
||||||
|
out_channels=out_channels * 4,
|
||||||
|
kernel_size=1,
|
||||||
|
act=None,
|
||||||
|
name=name + "_branch2c")
|
||||||
|
|
||||||
|
self.short = ShortCut(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels * 4,
|
||||||
|
stride=stride,
|
||||||
|
is_first=False,
|
||||||
|
name=name + "_branch1")
|
||||||
|
self.out_channels = out_channels * 4
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = self.conv0(x)
|
||||||
|
y = self.conv1(y)
|
||||||
|
y = self.conv2(y)
|
||||||
|
y = y + self.short(x)
|
||||||
|
y = F.relu(y)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlock(nn.Layer):
|
||||||
|
def __init__(self, in_channels, out_channels, stride, name, is_first):
|
||||||
|
super(BasicBlock, self).__init__()
|
||||||
|
self.conv0 = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
act='relu',
|
||||||
|
stride=stride,
|
||||||
|
name=name + "_branch2a")
|
||||||
|
self.conv1 = ConvBNLayer(
|
||||||
|
in_channels=out_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
act=None,
|
||||||
|
name=name + "_branch2b")
|
||||||
|
self.short = ShortCut(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
stride=stride,
|
||||||
|
is_first=is_first,
|
||||||
|
name=name + "_branch1")
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = self.conv0(x)
|
||||||
|
y = self.conv1(y)
|
||||||
|
y = y + self.short(x)
|
||||||
|
return F.relu(y)
|
|
@ -23,10 +23,15 @@ def build_head(config):
|
||||||
|
|
||||||
# rec head
|
# rec head
|
||||||
from .rec_ctc_head import CTCHead
|
from .rec_ctc_head import CTCHead
|
||||||
|
from .rec_att_head import AttentionHead
|
||||||
|
from .rec_srn_head import SRNHead
|
||||||
|
|
||||||
# cls head
|
# cls head
|
||||||
from .cls_head import ClsHead
|
from .cls_head import ClsHead
|
||||||
support_dict = ['DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead']
|
support_dict = [
|
||||||
|
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
||||||
|
'SRNHead'
|
||||||
|
]
|
||||||
|
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
assert module_name in support_dict, Exception('head only support {}'.format(
|
assert module_name in support_dict, Exception('head only support {}'.format(
|
||||||
|
|
|
@ -0,0 +1,199 @@
|
||||||
|
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionHead(nn.Layer):
|
||||||
|
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
|
||||||
|
super(AttentionHead, self).__init__()
|
||||||
|
self.input_size = in_channels
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_classes = out_channels
|
||||||
|
|
||||||
|
self.attention_cell = AttentionGRUCell(
|
||||||
|
in_channels, hidden_size, out_channels, use_gru=False)
|
||||||
|
self.generator = nn.Linear(hidden_size, out_channels)
|
||||||
|
|
||||||
|
def _char_to_onehot(self, input_char, onehot_dim):
|
||||||
|
input_ont_hot = F.one_hot(input_char, onehot_dim)
|
||||||
|
return input_ont_hot
|
||||||
|
|
||||||
|
def forward(self, inputs, targets=None, batch_max_length=25):
|
||||||
|
batch_size = inputs.shape[0]
|
||||||
|
num_steps = batch_max_length
|
||||||
|
|
||||||
|
hidden = paddle.zeros((batch_size, self.hidden_size))
|
||||||
|
output_hiddens = []
|
||||||
|
|
||||||
|
if targets is not None:
|
||||||
|
for i in range(num_steps):
|
||||||
|
char_onehots = self._char_to_onehot(
|
||||||
|
targets[:, i], onehot_dim=self.num_classes)
|
||||||
|
(outputs, hidden), alpha = self.attention_cell(hidden, inputs,
|
||||||
|
char_onehots)
|
||||||
|
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
|
||||||
|
output = paddle.concat(output_hiddens, axis=1)
|
||||||
|
probs = self.generator(output)
|
||||||
|
|
||||||
|
else:
|
||||||
|
targets = paddle.zeros(shape=[batch_size], dtype="int32")
|
||||||
|
probs = None
|
||||||
|
|
||||||
|
for i in range(num_steps):
|
||||||
|
char_onehots = self._char_to_onehot(
|
||||||
|
targets, onehot_dim=self.num_classes)
|
||||||
|
(outputs, hidden), alpha = self.attention_cell(hidden, inputs,
|
||||||
|
char_onehots)
|
||||||
|
probs_step = self.generator(outputs)
|
||||||
|
if probs is None:
|
||||||
|
probs = paddle.unsqueeze(probs_step, axis=1)
|
||||||
|
else:
|
||||||
|
probs = paddle.concat(
|
||||||
|
[probs, paddle.unsqueeze(
|
||||||
|
probs_step, axis=1)], axis=1)
|
||||||
|
next_input = probs_step.argmax(axis=1)
|
||||||
|
targets = next_input
|
||||||
|
|
||||||
|
return probs
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionGRUCell(nn.Layer):
|
||||||
|
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
|
||||||
|
super(AttentionGRUCell, self).__init__()
|
||||||
|
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
|
||||||
|
self.h2h = nn.Linear(hidden_size, hidden_size)
|
||||||
|
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
|
||||||
|
|
||||||
|
self.rnn = nn.GRUCell(
|
||||||
|
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
|
||||||
|
def forward(self, prev_hidden, batch_H, char_onehots):
|
||||||
|
|
||||||
|
batch_H_proj = self.i2h(batch_H)
|
||||||
|
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
|
||||||
|
|
||||||
|
res = paddle.add(batch_H_proj, prev_hidden_proj)
|
||||||
|
res = paddle.tanh(res)
|
||||||
|
e = self.score(res)
|
||||||
|
|
||||||
|
alpha = F.softmax(e, axis=1)
|
||||||
|
alpha = paddle.transpose(alpha, [0, 2, 1])
|
||||||
|
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
|
||||||
|
concat_context = paddle.concat([context, char_onehots], 1)
|
||||||
|
|
||||||
|
cur_hidden = self.rnn(concat_context, prev_hidden)
|
||||||
|
|
||||||
|
return cur_hidden, alpha
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionLSTM(nn.Layer):
|
||||||
|
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
|
||||||
|
super(AttentionLSTM, self).__init__()
|
||||||
|
self.input_size = in_channels
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_classes = out_channels
|
||||||
|
|
||||||
|
self.attention_cell = AttentionLSTMCell(
|
||||||
|
in_channels, hidden_size, out_channels, use_gru=False)
|
||||||
|
self.generator = nn.Linear(hidden_size, out_channels)
|
||||||
|
|
||||||
|
def _char_to_onehot(self, input_char, onehot_dim):
|
||||||
|
input_ont_hot = F.one_hot(input_char, onehot_dim)
|
||||||
|
return input_ont_hot
|
||||||
|
|
||||||
|
def forward(self, inputs, targets=None, batch_max_length=25):
|
||||||
|
batch_size = inputs.shape[0]
|
||||||
|
num_steps = batch_max_length
|
||||||
|
|
||||||
|
hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
|
||||||
|
(batch_size, self.hidden_size)))
|
||||||
|
output_hiddens = []
|
||||||
|
|
||||||
|
if targets is not None:
|
||||||
|
for i in range(num_steps):
|
||||||
|
# one-hot vectors for a i-th char
|
||||||
|
char_onehots = self._char_to_onehot(
|
||||||
|
targets[:, i], onehot_dim=self.num_classes)
|
||||||
|
hidden, alpha = self.attention_cell(hidden, inputs,
|
||||||
|
char_onehots)
|
||||||
|
|
||||||
|
hidden = (hidden[1][0], hidden[1][1])
|
||||||
|
output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
|
||||||
|
output = paddle.concat(output_hiddens, axis=1)
|
||||||
|
probs = self.generator(output)
|
||||||
|
|
||||||
|
else:
|
||||||
|
targets = paddle.zeros(shape=[batch_size], dtype="int32")
|
||||||
|
probs = None
|
||||||
|
|
||||||
|
for i in range(num_steps):
|
||||||
|
char_onehots = self._char_to_onehot(
|
||||||
|
targets, onehot_dim=self.num_classes)
|
||||||
|
hidden, alpha = self.attention_cell(hidden, inputs,
|
||||||
|
char_onehots)
|
||||||
|
probs_step = self.generator(hidden[0])
|
||||||
|
hidden = (hidden[1][0], hidden[1][1])
|
||||||
|
if probs is None:
|
||||||
|
probs = paddle.unsqueeze(probs_step, axis=1)
|
||||||
|
else:
|
||||||
|
probs = paddle.concat(
|
||||||
|
[probs, paddle.unsqueeze(
|
||||||
|
probs_step, axis=1)], axis=1)
|
||||||
|
|
||||||
|
next_input = probs_step.argmax(axis=1)
|
||||||
|
|
||||||
|
targets = next_input
|
||||||
|
|
||||||
|
return probs
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionLSTMCell(nn.Layer):
|
||||||
|
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
|
||||||
|
super(AttentionLSTMCell, self).__init__()
|
||||||
|
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
|
||||||
|
self.h2h = nn.Linear(hidden_size, hidden_size)
|
||||||
|
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
|
||||||
|
if not use_gru:
|
||||||
|
self.rnn = nn.LSTMCell(
|
||||||
|
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||||
|
else:
|
||||||
|
self.rnn = nn.GRUCell(
|
||||||
|
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
|
||||||
|
def forward(self, prev_hidden, batch_H, char_onehots):
|
||||||
|
batch_H_proj = self.i2h(batch_H)
|
||||||
|
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
|
||||||
|
res = paddle.add(batch_H_proj, prev_hidden_proj)
|
||||||
|
res = paddle.tanh(res)
|
||||||
|
e = self.score(res)
|
||||||
|
|
||||||
|
alpha = F.softmax(e, axis=1)
|
||||||
|
alpha = paddle.transpose(alpha, [0, 2, 1])
|
||||||
|
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
|
||||||
|
concat_context = paddle.concat([context, char_onehots], 1)
|
||||||
|
cur_hidden = self.rnn(concat_context, prev_hidden)
|
||||||
|
|
||||||
|
return cur_hidden, alpha
|
|
@ -0,0 +1,279 @@
|
||||||
|
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import math
|
||||||
|
import paddle
|
||||||
|
from paddle import nn, ParamAttr
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
import numpy as np
|
||||||
|
from .self_attention import WrapEncoderForFeature
|
||||||
|
from .self_attention import WrapEncoder
|
||||||
|
from paddle.static import Program
|
||||||
|
from ppocr.modeling.backbones.rec_resnet_fpn import ResNetFPN
|
||||||
|
import paddle.fluid.framework as framework
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
gradient_clip = 10
|
||||||
|
|
||||||
|
|
||||||
|
class PVAM(nn.Layer):
|
||||||
|
def __init__(self, in_channels, char_num, max_text_length, num_heads,
|
||||||
|
num_encoder_tus, hidden_dims):
|
||||||
|
super(PVAM, self).__init__()
|
||||||
|
self.char_num = char_num
|
||||||
|
self.max_length = max_text_length
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_encoder_TUs = num_encoder_tus
|
||||||
|
self.hidden_dims = hidden_dims
|
||||||
|
# Transformer encoder
|
||||||
|
t = 256
|
||||||
|
c = 512
|
||||||
|
self.wrap_encoder_for_feature = WrapEncoderForFeature(
|
||||||
|
src_vocab_size=1,
|
||||||
|
max_length=t,
|
||||||
|
n_layer=self.num_encoder_TUs,
|
||||||
|
n_head=self.num_heads,
|
||||||
|
d_key=int(self.hidden_dims / self.num_heads),
|
||||||
|
d_value=int(self.hidden_dims / self.num_heads),
|
||||||
|
d_model=self.hidden_dims,
|
||||||
|
d_inner_hid=self.hidden_dims,
|
||||||
|
prepostprocess_dropout=0.1,
|
||||||
|
attention_dropout=0.1,
|
||||||
|
relu_dropout=0.1,
|
||||||
|
preprocess_cmd="n",
|
||||||
|
postprocess_cmd="da",
|
||||||
|
weight_sharing=True)
|
||||||
|
|
||||||
|
# PVAM
|
||||||
|
self.flatten0 = paddle.nn.Flatten(start_axis=0, stop_axis=1)
|
||||||
|
self.fc0 = paddle.nn.Linear(
|
||||||
|
in_features=in_channels,
|
||||||
|
out_features=in_channels, )
|
||||||
|
self.emb = paddle.nn.Embedding(
|
||||||
|
num_embeddings=self.max_length, embedding_dim=in_channels)
|
||||||
|
self.flatten1 = paddle.nn.Flatten(start_axis=0, stop_axis=2)
|
||||||
|
self.fc1 = paddle.nn.Linear(
|
||||||
|
in_features=in_channels, out_features=1, bias_attr=False)
|
||||||
|
|
||||||
|
def forward(self, inputs, encoder_word_pos, gsrm_word_pos):
|
||||||
|
b, c, h, w = inputs.shape
|
||||||
|
conv_features = paddle.reshape(inputs, shape=[-1, c, h * w])
|
||||||
|
conv_features = paddle.transpose(conv_features, perm=[0, 2, 1])
|
||||||
|
# transformer encoder
|
||||||
|
b, t, c = conv_features.shape
|
||||||
|
|
||||||
|
enc_inputs = [conv_features, encoder_word_pos, None]
|
||||||
|
word_features = self.wrap_encoder_for_feature(enc_inputs)
|
||||||
|
|
||||||
|
# pvam
|
||||||
|
b, t, c = word_features.shape
|
||||||
|
word_features = self.fc0(word_features)
|
||||||
|
word_features_ = paddle.reshape(word_features, [-1, 1, t, c])
|
||||||
|
word_features_ = paddle.tile(word_features_, [1, self.max_length, 1, 1])
|
||||||
|
word_pos_feature = self.emb(gsrm_word_pos)
|
||||||
|
word_pos_feature_ = paddle.reshape(word_pos_feature,
|
||||||
|
[-1, self.max_length, 1, c])
|
||||||
|
word_pos_feature_ = paddle.tile(word_pos_feature_, [1, 1, t, 1])
|
||||||
|
y = word_pos_feature_ + word_features_
|
||||||
|
y = F.tanh(y)
|
||||||
|
attention_weight = self.fc1(y)
|
||||||
|
attention_weight = paddle.reshape(
|
||||||
|
attention_weight, shape=[-1, self.max_length, t])
|
||||||
|
attention_weight = F.softmax(attention_weight, axis=-1)
|
||||||
|
pvam_features = paddle.matmul(attention_weight,
|
||||||
|
word_features) #[b, max_length, c]
|
||||||
|
return pvam_features
|
||||||
|
|
||||||
|
|
||||||
|
class GSRM(nn.Layer):
|
||||||
|
def __init__(self, in_channels, char_num, max_text_length, num_heads,
|
||||||
|
num_encoder_tus, num_decoder_tus, hidden_dims):
|
||||||
|
super(GSRM, self).__init__()
|
||||||
|
self.char_num = char_num
|
||||||
|
self.max_length = max_text_length
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_encoder_TUs = num_encoder_tus
|
||||||
|
self.num_decoder_TUs = num_decoder_tus
|
||||||
|
self.hidden_dims = hidden_dims
|
||||||
|
|
||||||
|
self.fc0 = paddle.nn.Linear(
|
||||||
|
in_features=in_channels, out_features=self.char_num)
|
||||||
|
self.wrap_encoder0 = WrapEncoder(
|
||||||
|
src_vocab_size=self.char_num + 1,
|
||||||
|
max_length=self.max_length,
|
||||||
|
n_layer=self.num_decoder_TUs,
|
||||||
|
n_head=self.num_heads,
|
||||||
|
d_key=int(self.hidden_dims / self.num_heads),
|
||||||
|
d_value=int(self.hidden_dims / self.num_heads),
|
||||||
|
d_model=self.hidden_dims,
|
||||||
|
d_inner_hid=self.hidden_dims,
|
||||||
|
prepostprocess_dropout=0.1,
|
||||||
|
attention_dropout=0.1,
|
||||||
|
relu_dropout=0.1,
|
||||||
|
preprocess_cmd="n",
|
||||||
|
postprocess_cmd="da",
|
||||||
|
weight_sharing=True)
|
||||||
|
|
||||||
|
self.wrap_encoder1 = WrapEncoder(
|
||||||
|
src_vocab_size=self.char_num + 1,
|
||||||
|
max_length=self.max_length,
|
||||||
|
n_layer=self.num_decoder_TUs,
|
||||||
|
n_head=self.num_heads,
|
||||||
|
d_key=int(self.hidden_dims / self.num_heads),
|
||||||
|
d_value=int(self.hidden_dims / self.num_heads),
|
||||||
|
d_model=self.hidden_dims,
|
||||||
|
d_inner_hid=self.hidden_dims,
|
||||||
|
prepostprocess_dropout=0.1,
|
||||||
|
attention_dropout=0.1,
|
||||||
|
relu_dropout=0.1,
|
||||||
|
preprocess_cmd="n",
|
||||||
|
postprocess_cmd="da",
|
||||||
|
weight_sharing=True)
|
||||||
|
|
||||||
|
self.mul = lambda x: paddle.matmul(x=x,
|
||||||
|
y=self.wrap_encoder0.prepare_decoder.emb0.weight,
|
||||||
|
transpose_y=True)
|
||||||
|
|
||||||
|
def forward(self, inputs, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||||
|
gsrm_slf_attn_bias2):
|
||||||
|
# ===== GSRM Visual-to-semantic embedding block =====
|
||||||
|
b, t, c = inputs.shape
|
||||||
|
pvam_features = paddle.reshape(inputs, [-1, c])
|
||||||
|
word_out = self.fc0(pvam_features)
|
||||||
|
word_ids = paddle.argmax(F.softmax(word_out), axis=1)
|
||||||
|
word_ids = paddle.reshape(x=word_ids, shape=[-1, t, 1])
|
||||||
|
|
||||||
|
#===== GSRM Semantic reasoning block =====
|
||||||
|
"""
|
||||||
|
This module is achieved through bi-transformers,
|
||||||
|
ngram_feature1 is the froward one, ngram_fetaure2 is the backward one
|
||||||
|
"""
|
||||||
|
pad_idx = self.char_num
|
||||||
|
|
||||||
|
word1 = paddle.cast(word_ids, "float32")
|
||||||
|
word1 = F.pad(word1, [1, 0], value=1.0 * pad_idx, data_format="NLC")
|
||||||
|
word1 = paddle.cast(word1, "int64")
|
||||||
|
word1 = word1[:, :-1, :]
|
||||||
|
word2 = word_ids
|
||||||
|
|
||||||
|
enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1]
|
||||||
|
enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2]
|
||||||
|
|
||||||
|
gsrm_feature1 = self.wrap_encoder0(enc_inputs_1)
|
||||||
|
gsrm_feature2 = self.wrap_encoder1(enc_inputs_2)
|
||||||
|
|
||||||
|
gsrm_feature2 = F.pad(gsrm_feature2, [0, 1],
|
||||||
|
value=0.,
|
||||||
|
data_format="NLC")
|
||||||
|
gsrm_feature2 = gsrm_feature2[:, 1:, ]
|
||||||
|
gsrm_features = gsrm_feature1 + gsrm_feature2
|
||||||
|
|
||||||
|
gsrm_out = self.mul(gsrm_features)
|
||||||
|
|
||||||
|
b, t, c = gsrm_out.shape
|
||||||
|
gsrm_out = paddle.reshape(gsrm_out, [-1, c])
|
||||||
|
|
||||||
|
return gsrm_features, word_out, gsrm_out
|
||||||
|
|
||||||
|
|
||||||
|
class VSFD(nn.Layer):
|
||||||
|
def __init__(self, in_channels=512, pvam_ch=512, char_num=38):
|
||||||
|
super(VSFD, self).__init__()
|
||||||
|
self.char_num = char_num
|
||||||
|
self.fc0 = paddle.nn.Linear(
|
||||||
|
in_features=in_channels * 2, out_features=pvam_ch)
|
||||||
|
self.fc1 = paddle.nn.Linear(
|
||||||
|
in_features=pvam_ch, out_features=self.char_num)
|
||||||
|
|
||||||
|
def forward(self, pvam_feature, gsrm_feature):
|
||||||
|
b, t, c1 = pvam_feature.shape
|
||||||
|
b, t, c2 = gsrm_feature.shape
|
||||||
|
combine_feature_ = paddle.concat([pvam_feature, gsrm_feature], axis=2)
|
||||||
|
img_comb_feature_ = paddle.reshape(
|
||||||
|
combine_feature_, shape=[-1, c1 + c2])
|
||||||
|
img_comb_feature_map = self.fc0(img_comb_feature_)
|
||||||
|
img_comb_feature_map = F.sigmoid(img_comb_feature_map)
|
||||||
|
img_comb_feature_map = paddle.reshape(
|
||||||
|
img_comb_feature_map, shape=[-1, t, c1])
|
||||||
|
combine_feature = img_comb_feature_map * pvam_feature + (
|
||||||
|
1.0 - img_comb_feature_map) * gsrm_feature
|
||||||
|
img_comb_feature = paddle.reshape(combine_feature, shape=[-1, c1])
|
||||||
|
|
||||||
|
out = self.fc1(img_comb_feature)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class SRNHead(nn.Layer):
|
||||||
|
def __init__(self, in_channels, out_channels, max_text_length, num_heads,
|
||||||
|
num_encoder_TUs, num_decoder_TUs, hidden_dims, **kwargs):
|
||||||
|
super(SRNHead, self).__init__()
|
||||||
|
self.char_num = out_channels
|
||||||
|
self.max_length = max_text_length
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_encoder_TUs = num_encoder_TUs
|
||||||
|
self.num_decoder_TUs = num_decoder_TUs
|
||||||
|
self.hidden_dims = hidden_dims
|
||||||
|
|
||||||
|
self.pvam = PVAM(
|
||||||
|
in_channels=in_channels,
|
||||||
|
char_num=self.char_num,
|
||||||
|
max_text_length=self.max_length,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
num_encoder_tus=self.num_encoder_TUs,
|
||||||
|
hidden_dims=self.hidden_dims)
|
||||||
|
|
||||||
|
self.gsrm = GSRM(
|
||||||
|
in_channels=in_channels,
|
||||||
|
char_num=self.char_num,
|
||||||
|
max_text_length=self.max_length,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
num_encoder_tus=self.num_encoder_TUs,
|
||||||
|
num_decoder_tus=self.num_decoder_TUs,
|
||||||
|
hidden_dims=self.hidden_dims)
|
||||||
|
self.vsfd = VSFD(in_channels=in_channels, char_num=self.char_num)
|
||||||
|
|
||||||
|
self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
|
||||||
|
|
||||||
|
def forward(self, inputs, others):
|
||||||
|
encoder_word_pos = others[0]
|
||||||
|
gsrm_word_pos = others[1]
|
||||||
|
gsrm_slf_attn_bias1 = others[2]
|
||||||
|
gsrm_slf_attn_bias2 = others[3]
|
||||||
|
|
||||||
|
pvam_feature = self.pvam(inputs, encoder_word_pos, gsrm_word_pos)
|
||||||
|
|
||||||
|
gsrm_feature, word_out, gsrm_out = self.gsrm(
|
||||||
|
pvam_feature, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||||
|
gsrm_slf_attn_bias2)
|
||||||
|
|
||||||
|
final_out = self.vsfd(pvam_feature, gsrm_feature)
|
||||||
|
if not self.training:
|
||||||
|
final_out = F.softmax(final_out, axis=1)
|
||||||
|
|
||||||
|
_, decoded_out = paddle.topk(final_out, k=1)
|
||||||
|
|
||||||
|
predicts = OrderedDict([
|
||||||
|
('predict', final_out),
|
||||||
|
('pvam_feature', pvam_feature),
|
||||||
|
('decoded_out', decoded_out),
|
||||||
|
('word_out', word_out),
|
||||||
|
('gsrm_out', gsrm_out),
|
||||||
|
])
|
||||||
|
|
||||||
|
return predicts
|
|
@ -0,0 +1,409 @@
|
||||||
|
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import ParamAttr, nn
|
||||||
|
from paddle import nn, ParamAttr
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
import numpy as np
|
||||||
|
gradient_clip = 10
|
||||||
|
|
||||||
|
|
||||||
|
class WrapEncoderForFeature(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
src_vocab_size,
|
||||||
|
max_length,
|
||||||
|
n_layer,
|
||||||
|
n_head,
|
||||||
|
d_key,
|
||||||
|
d_value,
|
||||||
|
d_model,
|
||||||
|
d_inner_hid,
|
||||||
|
prepostprocess_dropout,
|
||||||
|
attention_dropout,
|
||||||
|
relu_dropout,
|
||||||
|
preprocess_cmd,
|
||||||
|
postprocess_cmd,
|
||||||
|
weight_sharing,
|
||||||
|
bos_idx=0):
|
||||||
|
super(WrapEncoderForFeature, self).__init__()
|
||||||
|
|
||||||
|
self.prepare_encoder = PrepareEncoder(
|
||||||
|
src_vocab_size,
|
||||||
|
d_model,
|
||||||
|
max_length,
|
||||||
|
prepostprocess_dropout,
|
||||||
|
bos_idx=bos_idx,
|
||||||
|
word_emb_param_name="src_word_emb_table")
|
||||||
|
self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
|
||||||
|
d_inner_hid, prepostprocess_dropout,
|
||||||
|
attention_dropout, relu_dropout, preprocess_cmd,
|
||||||
|
postprocess_cmd)
|
||||||
|
|
||||||
|
def forward(self, enc_inputs):
|
||||||
|
conv_features, src_pos, src_slf_attn_bias = enc_inputs
|
||||||
|
enc_input = self.prepare_encoder(conv_features, src_pos)
|
||||||
|
enc_output = self.encoder(enc_input, src_slf_attn_bias)
|
||||||
|
return enc_output
|
||||||
|
|
||||||
|
|
||||||
|
class WrapEncoder(nn.Layer):
|
||||||
|
"""
|
||||||
|
embedder + encoder
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
src_vocab_size,
|
||||||
|
max_length,
|
||||||
|
n_layer,
|
||||||
|
n_head,
|
||||||
|
d_key,
|
||||||
|
d_value,
|
||||||
|
d_model,
|
||||||
|
d_inner_hid,
|
||||||
|
prepostprocess_dropout,
|
||||||
|
attention_dropout,
|
||||||
|
relu_dropout,
|
||||||
|
preprocess_cmd,
|
||||||
|
postprocess_cmd,
|
||||||
|
weight_sharing,
|
||||||
|
bos_idx=0):
|
||||||
|
super(WrapEncoder, self).__init__()
|
||||||
|
|
||||||
|
self.prepare_decoder = PrepareDecoder(
|
||||||
|
src_vocab_size,
|
||||||
|
d_model,
|
||||||
|
max_length,
|
||||||
|
prepostprocess_dropout,
|
||||||
|
bos_idx=bos_idx)
|
||||||
|
self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
|
||||||
|
d_inner_hid, prepostprocess_dropout,
|
||||||
|
attention_dropout, relu_dropout, preprocess_cmd,
|
||||||
|
postprocess_cmd)
|
||||||
|
|
||||||
|
def forward(self, enc_inputs):
|
||||||
|
src_word, src_pos, src_slf_attn_bias = enc_inputs
|
||||||
|
enc_input = self.prepare_decoder(src_word, src_pos)
|
||||||
|
enc_output = self.encoder(enc_input, src_slf_attn_bias)
|
||||||
|
return enc_output
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Layer):
|
||||||
|
"""
|
||||||
|
encoder
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
n_layer,
|
||||||
|
n_head,
|
||||||
|
d_key,
|
||||||
|
d_value,
|
||||||
|
d_model,
|
||||||
|
d_inner_hid,
|
||||||
|
prepostprocess_dropout,
|
||||||
|
attention_dropout,
|
||||||
|
relu_dropout,
|
||||||
|
preprocess_cmd="n",
|
||||||
|
postprocess_cmd="da"):
|
||||||
|
|
||||||
|
super(Encoder, self).__init__()
|
||||||
|
|
||||||
|
self.encoder_layers = list()
|
||||||
|
for i in range(n_layer):
|
||||||
|
self.encoder_layers.append(
|
||||||
|
self.add_sublayer(
|
||||||
|
"layer_%d" % i,
|
||||||
|
EncoderLayer(n_head, d_key, d_value, d_model, d_inner_hid,
|
||||||
|
prepostprocess_dropout, attention_dropout,
|
||||||
|
relu_dropout, preprocess_cmd,
|
||||||
|
postprocess_cmd)))
|
||||||
|
self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
|
||||||
|
prepostprocess_dropout)
|
||||||
|
|
||||||
|
def forward(self, enc_input, attn_bias):
|
||||||
|
for encoder_layer in self.encoder_layers:
|
||||||
|
enc_output = encoder_layer(enc_input, attn_bias)
|
||||||
|
enc_input = enc_output
|
||||||
|
enc_output = self.processer(enc_output)
|
||||||
|
return enc_output
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderLayer(nn.Layer):
|
||||||
|
"""
|
||||||
|
EncoderLayer
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
n_head,
|
||||||
|
d_key,
|
||||||
|
d_value,
|
||||||
|
d_model,
|
||||||
|
d_inner_hid,
|
||||||
|
prepostprocess_dropout,
|
||||||
|
attention_dropout,
|
||||||
|
relu_dropout,
|
||||||
|
preprocess_cmd="n",
|
||||||
|
postprocess_cmd="da"):
|
||||||
|
|
||||||
|
super(EncoderLayer, self).__init__()
|
||||||
|
self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
|
||||||
|
prepostprocess_dropout)
|
||||||
|
self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
|
||||||
|
attention_dropout)
|
||||||
|
self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
|
||||||
|
prepostprocess_dropout)
|
||||||
|
|
||||||
|
self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
|
||||||
|
prepostprocess_dropout)
|
||||||
|
self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
|
||||||
|
self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
|
||||||
|
prepostprocess_dropout)
|
||||||
|
|
||||||
|
def forward(self, enc_input, attn_bias):
|
||||||
|
attn_output = self.self_attn(
|
||||||
|
self.preprocesser1(enc_input), None, None, attn_bias)
|
||||||
|
attn_output = self.postprocesser1(attn_output, enc_input)
|
||||||
|
ffn_output = self.ffn(self.preprocesser2(attn_output))
|
||||||
|
ffn_output = self.postprocesser2(ffn_output, attn_output)
|
||||||
|
return ffn_output
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Layer):
|
||||||
|
"""
|
||||||
|
Multi-Head Attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
|
||||||
|
super(MultiHeadAttention, self).__init__()
|
||||||
|
self.n_head = n_head
|
||||||
|
self.d_key = d_key
|
||||||
|
self.d_value = d_value
|
||||||
|
self.d_model = d_model
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
self.q_fc = paddle.nn.Linear(
|
||||||
|
in_features=d_model, out_features=d_key * n_head, bias_attr=False)
|
||||||
|
self.k_fc = paddle.nn.Linear(
|
||||||
|
in_features=d_model, out_features=d_key * n_head, bias_attr=False)
|
||||||
|
self.v_fc = paddle.nn.Linear(
|
||||||
|
in_features=d_model, out_features=d_value * n_head, bias_attr=False)
|
||||||
|
self.proj_fc = paddle.nn.Linear(
|
||||||
|
in_features=d_value * n_head, out_features=d_model, bias_attr=False)
|
||||||
|
|
||||||
|
def _prepare_qkv(self, queries, keys, values, cache=None):
|
||||||
|
if keys is None: # self-attention
|
||||||
|
keys, values = queries, queries
|
||||||
|
static_kv = False
|
||||||
|
else: # cross-attention
|
||||||
|
static_kv = True
|
||||||
|
|
||||||
|
q = self.q_fc(queries)
|
||||||
|
q = paddle.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
|
||||||
|
q = paddle.transpose(x=q, perm=[0, 2, 1, 3])
|
||||||
|
|
||||||
|
if cache is not None and static_kv and "static_k" in cache:
|
||||||
|
# for encoder-decoder attention in inference and has cached
|
||||||
|
k = cache["static_k"]
|
||||||
|
v = cache["static_v"]
|
||||||
|
else:
|
||||||
|
k = self.k_fc(keys)
|
||||||
|
v = self.v_fc(values)
|
||||||
|
k = paddle.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
|
||||||
|
k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
|
||||||
|
v = paddle.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
|
||||||
|
v = paddle.transpose(x=v, perm=[0, 2, 1, 3])
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
if static_kv and not "static_k" in cache:
|
||||||
|
# for encoder-decoder attention in inference and has not cached
|
||||||
|
cache["static_k"], cache["static_v"] = k, v
|
||||||
|
elif not static_kv:
|
||||||
|
# for decoder self-attention in inference
|
||||||
|
cache_k, cache_v = cache["k"], cache["v"]
|
||||||
|
k = paddle.concat([cache_k, k], axis=2)
|
||||||
|
v = paddle.concat([cache_v, v], axis=2)
|
||||||
|
cache["k"], cache["v"] = k, v
|
||||||
|
|
||||||
|
return q, k, v
|
||||||
|
|
||||||
|
def forward(self, queries, keys, values, attn_bias, cache=None):
|
||||||
|
# compute q ,k ,v
|
||||||
|
keys = queries if keys is None else keys
|
||||||
|
values = keys if values is None else values
|
||||||
|
q, k, v = self._prepare_qkv(queries, keys, values, cache)
|
||||||
|
|
||||||
|
# scale dot product attention
|
||||||
|
product = paddle.matmul(x=q, y=k, transpose_y=True)
|
||||||
|
product = product * self.d_model**-0.5
|
||||||
|
if attn_bias is not None:
|
||||||
|
product += attn_bias
|
||||||
|
weights = F.softmax(product)
|
||||||
|
if self.dropout_rate:
|
||||||
|
weights = F.dropout(
|
||||||
|
weights, p=self.dropout_rate, mode="downscale_in_infer")
|
||||||
|
out = paddle.matmul(weights, v)
|
||||||
|
|
||||||
|
# combine heads
|
||||||
|
out = paddle.transpose(out, perm=[0, 2, 1, 3])
|
||||||
|
out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
|
||||||
|
|
||||||
|
# project to output
|
||||||
|
out = self.proj_fc(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class PrePostProcessLayer(nn.Layer):
|
||||||
|
"""
|
||||||
|
PrePostProcessLayer
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, process_cmd, d_model, dropout_rate):
|
||||||
|
super(PrePostProcessLayer, self).__init__()
|
||||||
|
self.process_cmd = process_cmd
|
||||||
|
self.functors = []
|
||||||
|
for cmd in self.process_cmd:
|
||||||
|
if cmd == "a": # add residual connection
|
||||||
|
self.functors.append(lambda x, y: x + y if y is not None else x)
|
||||||
|
elif cmd == "n": # add layer normalization
|
||||||
|
self.functors.append(
|
||||||
|
self.add_sublayer(
|
||||||
|
"layer_norm_%d" % len(
|
||||||
|
self.sublayers(include_sublayers=False)),
|
||||||
|
paddle.nn.LayerNorm(
|
||||||
|
normalized_shape=d_model,
|
||||||
|
weight_attr=fluid.ParamAttr(
|
||||||
|
initializer=fluid.initializer.Constant(1.)),
|
||||||
|
bias_attr=fluid.ParamAttr(
|
||||||
|
initializer=fluid.initializer.Constant(0.)))))
|
||||||
|
elif cmd == "d": # add dropout
|
||||||
|
self.functors.append(lambda x: F.dropout(
|
||||||
|
x, p=dropout_rate, mode="downscale_in_infer")
|
||||||
|
if dropout_rate else x)
|
||||||
|
|
||||||
|
def forward(self, x, residual=None):
|
||||||
|
for i, cmd in enumerate(self.process_cmd):
|
||||||
|
if cmd == "a":
|
||||||
|
x = self.functors[i](x, residual)
|
||||||
|
else:
|
||||||
|
x = self.functors[i](x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PrepareEncoder(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
src_vocab_size,
|
||||||
|
src_emb_dim,
|
||||||
|
src_max_len,
|
||||||
|
dropout_rate=0,
|
||||||
|
bos_idx=0,
|
||||||
|
word_emb_param_name=None,
|
||||||
|
pos_enc_param_name=None):
|
||||||
|
super(PrepareEncoder, self).__init__()
|
||||||
|
self.src_emb_dim = src_emb_dim
|
||||||
|
self.src_max_len = src_max_len
|
||||||
|
self.emb = paddle.nn.Embedding(
|
||||||
|
num_embeddings=self.src_max_len,
|
||||||
|
embedding_dim=self.src_emb_dim,
|
||||||
|
sparse=True)
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
|
||||||
|
def forward(self, src_word, src_pos):
|
||||||
|
src_word_emb = src_word
|
||||||
|
src_word_emb = fluid.layers.cast(src_word_emb, 'float32')
|
||||||
|
src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
|
||||||
|
src_pos = paddle.squeeze(src_pos, axis=-1)
|
||||||
|
src_pos_enc = self.emb(src_pos)
|
||||||
|
src_pos_enc.stop_gradient = True
|
||||||
|
enc_input = src_word_emb + src_pos_enc
|
||||||
|
if self.dropout_rate:
|
||||||
|
out = F.dropout(
|
||||||
|
x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
|
||||||
|
else:
|
||||||
|
out = enc_input
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class PrepareDecoder(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
src_vocab_size,
|
||||||
|
src_emb_dim,
|
||||||
|
src_max_len,
|
||||||
|
dropout_rate=0,
|
||||||
|
bos_idx=0,
|
||||||
|
word_emb_param_name=None,
|
||||||
|
pos_enc_param_name=None):
|
||||||
|
super(PrepareDecoder, self).__init__()
|
||||||
|
self.src_emb_dim = src_emb_dim
|
||||||
|
"""
|
||||||
|
self.emb0 = Embedding(num_embeddings=src_vocab_size,
|
||||||
|
embedding_dim=src_emb_dim)
|
||||||
|
"""
|
||||||
|
self.emb0 = paddle.nn.Embedding(
|
||||||
|
num_embeddings=src_vocab_size,
|
||||||
|
embedding_dim=self.src_emb_dim,
|
||||||
|
padding_idx=bos_idx,
|
||||||
|
weight_attr=paddle.ParamAttr(
|
||||||
|
name=word_emb_param_name,
|
||||||
|
initializer=nn.initializer.Normal(0., src_emb_dim**-0.5)))
|
||||||
|
self.emb1 = paddle.nn.Embedding(
|
||||||
|
num_embeddings=src_max_len,
|
||||||
|
embedding_dim=self.src_emb_dim,
|
||||||
|
weight_attr=paddle.ParamAttr(name=pos_enc_param_name))
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
|
||||||
|
def forward(self, src_word, src_pos):
|
||||||
|
src_word = fluid.layers.cast(src_word, 'int64')
|
||||||
|
src_word = paddle.squeeze(src_word, axis=-1)
|
||||||
|
src_word_emb = self.emb0(src_word)
|
||||||
|
src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
|
||||||
|
src_pos = paddle.squeeze(src_pos, axis=-1)
|
||||||
|
src_pos_enc = self.emb1(src_pos)
|
||||||
|
src_pos_enc.stop_gradient = True
|
||||||
|
enc_input = src_word_emb + src_pos_enc
|
||||||
|
if self.dropout_rate:
|
||||||
|
out = F.dropout(
|
||||||
|
x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
|
||||||
|
else:
|
||||||
|
out = enc_input
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class FFN(nn.Layer):
|
||||||
|
"""
|
||||||
|
Feed-Forward Network
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, d_inner_hid, d_model, dropout_rate):
|
||||||
|
super(FFN, self).__init__()
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
self.fc1 = paddle.nn.Linear(
|
||||||
|
in_features=d_model, out_features=d_inner_hid)
|
||||||
|
self.fc2 = paddle.nn.Linear(
|
||||||
|
in_features=d_inner_hid, out_features=d_model)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
hidden = self.fc1(x)
|
||||||
|
hidden = F.relu(hidden)
|
||||||
|
if self.dropout_rate:
|
||||||
|
hidden = F.dropout(
|
||||||
|
hidden, p=self.dropout_rate, mode="downscale_in_infer")
|
||||||
|
out = self.fc2(hidden)
|
||||||
|
return out
|
|
@ -26,11 +26,12 @@ def build_post_process(config, global_config=None):
|
||||||
from .db_postprocess import DBPostProcess
|
from .db_postprocess import DBPostProcess
|
||||||
from .east_postprocess import EASTPostProcess
|
from .east_postprocess import EASTPostProcess
|
||||||
from .sast_postprocess import SASTPostProcess
|
from .sast_postprocess import SASTPostProcess
|
||||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode
|
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
|
||||||
from .cls_postprocess import ClsPostProcess
|
from .cls_postprocess import ClsPostProcess
|
||||||
|
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess'
|
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||||
|
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode'
|
||||||
]
|
]
|
||||||
|
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
|
|
|
@ -33,6 +33,9 @@ class BaseRecLabelDecode(object):
|
||||||
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
||||||
support_character_type, character_type)
|
support_character_type, character_type)
|
||||||
|
|
||||||
|
self.beg_str = "sos"
|
||||||
|
self.end_str = "eos"
|
||||||
|
|
||||||
if character_type == "en":
|
if character_type == "en":
|
||||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||||
dict_character = list(self.character_str)
|
dict_character = list(self.character_str)
|
||||||
|
@ -109,7 +112,6 @@ class CTCLabelDecode(BaseRecLabelDecode):
|
||||||
def __call__(self, preds, label=None, *args, **kwargs):
|
def __call__(self, preds, label=None, *args, **kwargs):
|
||||||
if isinstance(preds, paddle.Tensor):
|
if isinstance(preds, paddle.Tensor):
|
||||||
preds = preds.numpy()
|
preds = preds.numpy()
|
||||||
|
|
||||||
preds_idx = preds.argmax(axis=2)
|
preds_idx = preds.argmax(axis=2)
|
||||||
preds_prob = preds.max(axis=2)
|
preds_prob = preds.max(axis=2)
|
||||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
||||||
|
@ -133,16 +135,143 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(AttnLabelDecode, self).__init__(character_dict_path,
|
super(AttnLabelDecode, self).__init__(character_dict_path,
|
||||||
character_type, use_space_char)
|
character_type, use_space_char)
|
||||||
self.beg_str = "sos"
|
|
||||||
self.end_str = "eos"
|
|
||||||
|
|
||||||
def add_special_char(self, dict_character):
|
def add_special_char(self, dict_character):
|
||||||
dict_character = [self.beg_str, self.end_str] + dict_character
|
self.beg_str = "sos"
|
||||||
|
self.end_str = "eos"
|
||||||
|
dict_character = dict_character
|
||||||
|
dict_character = [self.beg_str] + dict_character + [self.end_str]
|
||||||
return dict_character
|
return dict_character
|
||||||
|
|
||||||
def __call__(self, text):
|
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||||
|
""" convert text-index into text-label. """
|
||||||
|
result_list = []
|
||||||
|
ignored_tokens = self.get_ignored_tokens()
|
||||||
|
[beg_idx, end_idx] = self.get_ignored_tokens()
|
||||||
|
batch_size = len(text_index)
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
char_list = []
|
||||||
|
conf_list = []
|
||||||
|
for idx in range(len(text_index[batch_idx])):
|
||||||
|
if text_index[batch_idx][idx] in ignored_tokens:
|
||||||
|
continue
|
||||||
|
if int(text_index[batch_idx][idx]) == int(end_idx):
|
||||||
|
break
|
||||||
|
if is_remove_duplicate:
|
||||||
|
# only for predict
|
||||||
|
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||||
|
batch_idx][idx]:
|
||||||
|
continue
|
||||||
|
char_list.append(self.character[int(text_index[batch_idx][
|
||||||
|
idx])])
|
||||||
|
if text_prob is not None:
|
||||||
|
conf_list.append(text_prob[batch_idx][idx])
|
||||||
|
else:
|
||||||
|
conf_list.append(1)
|
||||||
|
text = ''.join(char_list)
|
||||||
|
result_list.append((text, np.mean(conf_list)))
|
||||||
|
return result_list
|
||||||
|
|
||||||
|
def __call__(self, preds, label=None, *args, **kwargs):
|
||||||
|
"""
|
||||||
text = self.decode(text)
|
text = self.decode(text)
|
||||||
|
if label is None:
|
||||||
return text
|
return text
|
||||||
|
else:
|
||||||
|
label = self.decode(label, is_remove_duplicate=False)
|
||||||
|
return text, label
|
||||||
|
"""
|
||||||
|
if isinstance(preds, paddle.Tensor):
|
||||||
|
preds = preds.numpy()
|
||||||
|
|
||||||
|
preds_idx = preds.argmax(axis=2)
|
||||||
|
preds_prob = preds.max(axis=2)
|
||||||
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||||
|
if label is None:
|
||||||
|
return text
|
||||||
|
label = self.decode(label, is_remove_duplicate=False)
|
||||||
|
return text, label
|
||||||
|
|
||||||
|
def get_ignored_tokens(self):
|
||||||
|
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||||
|
end_idx = self.get_beg_end_flag_idx("end")
|
||||||
|
return [beg_idx, end_idx]
|
||||||
|
|
||||||
|
def get_beg_end_flag_idx(self, beg_or_end):
|
||||||
|
if beg_or_end == "beg":
|
||||||
|
idx = np.array(self.dict[self.beg_str])
|
||||||
|
elif beg_or_end == "end":
|
||||||
|
idx = np.array(self.dict[self.end_str])
|
||||||
|
else:
|
||||||
|
assert False, "unsupport type %s in get_beg_end_flag_idx" \
|
||||||
|
% beg_or_end
|
||||||
|
return idx
|
||||||
|
|
||||||
|
|
||||||
|
class SRNLabelDecode(BaseRecLabelDecode):
|
||||||
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
character_dict_path=None,
|
||||||
|
character_type='en',
|
||||||
|
use_space_char=False,
|
||||||
|
**kwargs):
|
||||||
|
super(SRNLabelDecode, self).__init__(character_dict_path,
|
||||||
|
character_type, use_space_char)
|
||||||
|
|
||||||
|
def __call__(self, preds, label=None, *args, **kwargs):
|
||||||
|
pred = preds['predict']
|
||||||
|
char_num = len(self.character_str) + 2
|
||||||
|
if isinstance(pred, paddle.Tensor):
|
||||||
|
pred = pred.numpy()
|
||||||
|
pred = np.reshape(pred, [-1, char_num])
|
||||||
|
|
||||||
|
preds_idx = np.argmax(pred, axis=1)
|
||||||
|
preds_prob = np.max(pred, axis=1)
|
||||||
|
|
||||||
|
preds_idx = np.reshape(preds_idx, [-1, 25])
|
||||||
|
|
||||||
|
preds_prob = np.reshape(preds_prob, [-1, 25])
|
||||||
|
|
||||||
|
text = self.decode(preds_idx, preds_prob)
|
||||||
|
|
||||||
|
if label is None:
|
||||||
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||||
|
return text
|
||||||
|
label = self.decode(label)
|
||||||
|
return text, label
|
||||||
|
|
||||||
|
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||||
|
""" convert text-index into text-label. """
|
||||||
|
result_list = []
|
||||||
|
ignored_tokens = self.get_ignored_tokens()
|
||||||
|
batch_size = len(text_index)
|
||||||
|
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
char_list = []
|
||||||
|
conf_list = []
|
||||||
|
for idx in range(len(text_index[batch_idx])):
|
||||||
|
if text_index[batch_idx][idx] in ignored_tokens:
|
||||||
|
continue
|
||||||
|
if is_remove_duplicate:
|
||||||
|
# only for predict
|
||||||
|
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||||
|
batch_idx][idx]:
|
||||||
|
continue
|
||||||
|
char_list.append(self.character[int(text_index[batch_idx][
|
||||||
|
idx])])
|
||||||
|
if text_prob is not None:
|
||||||
|
conf_list.append(text_prob[batch_idx][idx])
|
||||||
|
else:
|
||||||
|
conf_list.append(1)
|
||||||
|
|
||||||
|
text = ''.join(char_list)
|
||||||
|
result_list.append((text, np.mean(conf_list)))
|
||||||
|
return result_list
|
||||||
|
|
||||||
|
def add_special_char(self, dict_character):
|
||||||
|
dict_character = dict_character + [self.beg_str, self.end_str]
|
||||||
|
return dict_character
|
||||||
|
|
||||||
def get_ignored_tokens(self):
|
def get_ignored_tokens(self):
|
||||||
beg_idx = self.get_beg_end_flag_idx("beg")
|
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||||
|
|
|
@ -31,6 +31,14 @@ from ppocr.utils.logging import get_logger
|
||||||
from tools.program import load_config, merge_config, ArgsParser
|
from tools.program import load_config, merge_config, ArgsParser
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("-c", "--config", help="configuration file to use")
|
||||||
|
parser.add_argument(
|
||||||
|
"-o", "--output_path", type=str, default='./output/infer/')
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
FLAGS = ArgsParser().parse_args()
|
FLAGS = ArgsParser().parse_args()
|
||||||
config = load_config(FLAGS.config)
|
config = load_config(FLAGS.config)
|
||||||
|
@ -52,16 +60,39 @@ def main():
|
||||||
|
|
||||||
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
|
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
|
||||||
|
|
||||||
|
if config['Architecture']['algorithm'] == "SRN":
|
||||||
|
other_shape = [
|
||||||
|
paddle.static.InputSpec(
|
||||||
|
shape=[None, 1, 64, 256], dtype='float32'), [
|
||||||
|
paddle.static.InputSpec(
|
||||||
|
shape=[None, 256, 1],
|
||||||
|
dtype="int64"), paddle.static.InputSpec(
|
||||||
|
shape=[None, 25, 1],
|
||||||
|
dtype="int64"), paddle.static.InputSpec(
|
||||||
|
shape=[None, 8, 25, 25], dtype="int64"),
|
||||||
|
paddle.static.InputSpec(
|
||||||
|
shape=[None, 8, 25, 25], dtype="int64")
|
||||||
|
]
|
||||||
|
]
|
||||||
|
model = to_static(model, input_spec=other_shape)
|
||||||
|
else:
|
||||||
infer_shape = [3, -1, -1]
|
infer_shape = [3, -1, -1]
|
||||||
if config['Architecture']['model_type'] == "rec":
|
if config['Architecture']['model_type'] == "rec":
|
||||||
infer_shape = [3, 32, -1]
|
infer_shape = [3, 32, -1] # for rec model, H must be 32
|
||||||
|
if 'Transform' in config['Architecture'] and config['Architecture'][
|
||||||
|
'Transform'] is not None and config['Architecture'][
|
||||||
|
'Transform']['name'] == 'TPS':
|
||||||
|
logger.info(
|
||||||
|
'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
|
||||||
|
)
|
||||||
|
infer_shape[-1] = 100
|
||||||
model = to_static(
|
model = to_static(
|
||||||
model,
|
model,
|
||||||
input_spec=[
|
input_spec=[
|
||||||
paddle.static.InputSpec(
|
paddle.static.InputSpec(
|
||||||
shape=[None] + infer_shape, dtype='float32')
|
shape=[None] + infer_shape, dtype='float32')
|
||||||
])
|
])
|
||||||
|
|
||||||
paddle.jit.save(model, save_path)
|
paddle.jit.save(model, save_path)
|
||||||
logger.info('inference model is saved to {}'.format(save_path))
|
logger.info('inference model is saved to {}'.format(save_path))
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,7 @@ import numpy as np
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
import paddle
|
||||||
|
|
||||||
import tools.infer.utility as utility
|
import tools.infer.utility as utility
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
|
@ -46,6 +47,13 @@ class TextRecognizer(object):
|
||||||
"character_dict_path": args.rec_char_dict_path,
|
"character_dict_path": args.rec_char_dict_path,
|
||||||
"use_space_char": args.use_space_char
|
"use_space_char": args.use_space_char
|
||||||
}
|
}
|
||||||
|
if self.rec_algorithm == "SRN":
|
||||||
|
postprocess_params = {
|
||||||
|
'name': 'SRNLabelDecode',
|
||||||
|
"character_type": args.rec_char_type,
|
||||||
|
"character_dict_path": args.rec_char_dict_path,
|
||||||
|
"use_space_char": args.use_space_char
|
||||||
|
}
|
||||||
self.postprocess_op = build_post_process(postprocess_params)
|
self.postprocess_op = build_post_process(postprocess_params)
|
||||||
self.predictor, self.input_tensor, self.output_tensors = \
|
self.predictor, self.input_tensor, self.output_tensors = \
|
||||||
utility.create_predictor(args, 'rec', logger)
|
utility.create_predictor(args, 'rec', logger)
|
||||||
|
@ -70,6 +78,78 @@ class TextRecognizer(object):
|
||||||
padding_im[:, :, 0:resized_w] = resized_image
|
padding_im[:, :, 0:resized_w] = resized_image
|
||||||
return padding_im
|
return padding_im
|
||||||
|
|
||||||
|
def resize_norm_img_srn(self, img, image_shape):
|
||||||
|
imgC, imgH, imgW = image_shape
|
||||||
|
|
||||||
|
img_black = np.zeros((imgH, imgW))
|
||||||
|
im_hei = img.shape[0]
|
||||||
|
im_wid = img.shape[1]
|
||||||
|
|
||||||
|
if im_wid <= im_hei * 1:
|
||||||
|
img_new = cv2.resize(img, (imgH * 1, imgH))
|
||||||
|
elif im_wid <= im_hei * 2:
|
||||||
|
img_new = cv2.resize(img, (imgH * 2, imgH))
|
||||||
|
elif im_wid <= im_hei * 3:
|
||||||
|
img_new = cv2.resize(img, (imgH * 3, imgH))
|
||||||
|
else:
|
||||||
|
img_new = cv2.resize(img, (imgW, imgH))
|
||||||
|
|
||||||
|
img_np = np.asarray(img_new)
|
||||||
|
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
|
||||||
|
img_black[:, 0:img_np.shape[1]] = img_np
|
||||||
|
img_black = img_black[:, :, np.newaxis]
|
||||||
|
|
||||||
|
row, col, c = img_black.shape
|
||||||
|
c = 1
|
||||||
|
|
||||||
|
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
||||||
|
|
||||||
|
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
|
||||||
|
|
||||||
|
imgC, imgH, imgW = image_shape
|
||||||
|
feature_dim = int((imgH / 8) * (imgW / 8))
|
||||||
|
|
||||||
|
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
|
||||||
|
(feature_dim, 1)).astype('int64')
|
||||||
|
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
|
||||||
|
(max_text_length, 1)).astype('int64')
|
||||||
|
|
||||||
|
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||||
|
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
|
||||||
|
[-1, 1, max_text_length, max_text_length])
|
||||||
|
gsrm_slf_attn_bias1 = np.tile(
|
||||||
|
gsrm_slf_attn_bias1,
|
||||||
|
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
||||||
|
|
||||||
|
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
|
||||||
|
[-1, 1, max_text_length, max_text_length])
|
||||||
|
gsrm_slf_attn_bias2 = np.tile(
|
||||||
|
gsrm_slf_attn_bias2,
|
||||||
|
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
||||||
|
|
||||||
|
encoder_word_pos = encoder_word_pos[np.newaxis, :]
|
||||||
|
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
|
||||||
|
|
||||||
|
return [
|
||||||
|
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||||
|
gsrm_slf_attn_bias2
|
||||||
|
]
|
||||||
|
|
||||||
|
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
|
||||||
|
norm_img = self.resize_norm_img_srn(img, image_shape)
|
||||||
|
norm_img = norm_img[np.newaxis, :]
|
||||||
|
|
||||||
|
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
||||||
|
self.srn_other_inputs(image_shape, num_heads, max_text_length)
|
||||||
|
|
||||||
|
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
|
||||||
|
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
|
||||||
|
encoder_word_pos = encoder_word_pos.astype(np.int64)
|
||||||
|
gsrm_word_pos = gsrm_word_pos.astype(np.int64)
|
||||||
|
|
||||||
|
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||||
|
gsrm_slf_attn_bias2)
|
||||||
|
|
||||||
def __call__(self, img_list):
|
def __call__(self, img_list):
|
||||||
img_num = len(img_list)
|
img_num = len(img_list)
|
||||||
# Calculate the aspect ratio of all text bars
|
# Calculate the aspect ratio of all text bars
|
||||||
|
@ -93,21 +173,64 @@ class TextRecognizer(object):
|
||||||
wh_ratio = w * 1.0 / h
|
wh_ratio = w * 1.0 / h
|
||||||
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
||||||
for ino in range(beg_img_no, end_img_no):
|
for ino in range(beg_img_no, end_img_no):
|
||||||
# norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
|
if self.rec_algorithm != "SRN":
|
||||||
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
||||||
max_wh_ratio)
|
max_wh_ratio)
|
||||||
norm_img = norm_img[np.newaxis, :]
|
norm_img = norm_img[np.newaxis, :]
|
||||||
norm_img_batch.append(norm_img)
|
norm_img_batch.append(norm_img)
|
||||||
|
else:
|
||||||
|
norm_img = self.process_image_srn(
|
||||||
|
img_list[indices[ino]], self.rec_image_shape, 8, 25)
|
||||||
|
encoder_word_pos_list = []
|
||||||
|
gsrm_word_pos_list = []
|
||||||
|
gsrm_slf_attn_bias1_list = []
|
||||||
|
gsrm_slf_attn_bias2_list = []
|
||||||
|
encoder_word_pos_list.append(norm_img[1])
|
||||||
|
gsrm_word_pos_list.append(norm_img[2])
|
||||||
|
gsrm_slf_attn_bias1_list.append(norm_img[3])
|
||||||
|
gsrm_slf_attn_bias2_list.append(norm_img[4])
|
||||||
|
norm_img_batch.append(norm_img[0])
|
||||||
norm_img_batch = np.concatenate(norm_img_batch)
|
norm_img_batch = np.concatenate(norm_img_batch)
|
||||||
norm_img_batch = norm_img_batch.copy()
|
norm_img_batch = norm_img_batch.copy()
|
||||||
|
|
||||||
|
if self.rec_algorithm == "SRN":
|
||||||
starttime = time.time()
|
starttime = time.time()
|
||||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
|
||||||
|
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
|
||||||
|
gsrm_slf_attn_bias1_list = np.concatenate(
|
||||||
|
gsrm_slf_attn_bias1_list)
|
||||||
|
gsrm_slf_attn_bias2_list = np.concatenate(
|
||||||
|
gsrm_slf_attn_bias2_list)
|
||||||
|
|
||||||
|
inputs = [
|
||||||
|
norm_img_batch,
|
||||||
|
encoder_word_pos_list,
|
||||||
|
gsrm_word_pos_list,
|
||||||
|
gsrm_slf_attn_bias1_list,
|
||||||
|
gsrm_slf_attn_bias2_list,
|
||||||
|
]
|
||||||
|
input_names = self.predictor.get_input_names()
|
||||||
|
for i in range(len(input_names)):
|
||||||
|
input_tensor = self.predictor.get_input_handle(input_names[
|
||||||
|
i])
|
||||||
|
input_tensor.copy_from_cpu(inputs[i])
|
||||||
self.predictor.run()
|
self.predictor.run()
|
||||||
|
outputs = []
|
||||||
|
for output_tensor in self.output_tensors:
|
||||||
|
output = output_tensor.copy_to_cpu()
|
||||||
|
outputs.append(output)
|
||||||
|
preds = {"predict": outputs[2]}
|
||||||
|
else:
|
||||||
|
starttime = time.time()
|
||||||
|
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||||
|
self.predictor.run()
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
for output_tensor in self.output_tensors:
|
for output_tensor in self.output_tensors:
|
||||||
output = output_tensor.copy_to_cpu()
|
output = output_tensor.copy_to_cpu()
|
||||||
outputs.append(output)
|
outputs.append(output)
|
||||||
preds = outputs[0]
|
preds = outputs[0]
|
||||||
|
|
||||||
rec_result = self.postprocess_op(preds)
|
rec_result = self.postprocess_op(preds)
|
||||||
for rno in range(len(rec_result)):
|
for rno in range(len(rec_result)):
|
||||||
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
|
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
|
||||||
|
|
|
@ -62,6 +62,12 @@ def main():
|
||||||
elif op_name in ['RecResizeImg']:
|
elif op_name in ['RecResizeImg']:
|
||||||
op[op_name]['infer_mode'] = True
|
op[op_name]['infer_mode'] = True
|
||||||
elif op_name == 'KeepKeys':
|
elif op_name == 'KeepKeys':
|
||||||
|
if config['Architecture']['algorithm'] == "SRN":
|
||||||
|
op[op_name]['keep_keys'] = [
|
||||||
|
'image', 'encoder_word_pos', 'gsrm_word_pos',
|
||||||
|
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
|
||||||
|
]
|
||||||
|
else:
|
||||||
op[op_name]['keep_keys'] = ['image']
|
op[op_name]['keep_keys'] = ['image']
|
||||||
transforms.append(op)
|
transforms.append(op)
|
||||||
global_config['infer_mode'] = True
|
global_config['infer_mode'] = True
|
||||||
|
@ -74,9 +80,24 @@ def main():
|
||||||
img = f.read()
|
img = f.read()
|
||||||
data = {'image': img}
|
data = {'image': img}
|
||||||
batch = transform(data, ops)
|
batch = transform(data, ops)
|
||||||
|
if config['Architecture']['algorithm'] == "SRN":
|
||||||
|
encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
|
||||||
|
gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
|
||||||
|
gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
|
||||||
|
gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
|
||||||
|
|
||||||
|
others = [
|
||||||
|
paddle.to_tensor(encoder_word_pos_list),
|
||||||
|
paddle.to_tensor(gsrm_word_pos_list),
|
||||||
|
paddle.to_tensor(gsrm_slf_attn_bias1_list),
|
||||||
|
paddle.to_tensor(gsrm_slf_attn_bias2_list)
|
||||||
|
]
|
||||||
|
|
||||||
images = np.expand_dims(batch[0], axis=0)
|
images = np.expand_dims(batch[0], axis=0)
|
||||||
images = paddle.to_tensor(images)
|
images = paddle.to_tensor(images)
|
||||||
|
if config['Architecture']['algorithm'] == "SRN":
|
||||||
|
preds = model(images, others)
|
||||||
|
else:
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
post_result = post_process_class(preds)
|
post_result = post_process_class(preds)
|
||||||
for rec_reuslt in post_result:
|
for rec_reuslt in post_result:
|
||||||
|
|
|
@ -174,6 +174,7 @@ def train(config,
|
||||||
best_model_dict = {main_indicator: 0}
|
best_model_dict = {main_indicator: 0}
|
||||||
best_model_dict.update(pre_best_model_dict)
|
best_model_dict.update(pre_best_model_dict)
|
||||||
train_stats = TrainingStats(log_smooth_window, ['lr'])
|
train_stats = TrainingStats(log_smooth_window, ['lr'])
|
||||||
|
model_average = False
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
if 'start_epoch' in best_model_dict:
|
if 'start_epoch' in best_model_dict:
|
||||||
|
@ -194,6 +195,11 @@ def train(config,
|
||||||
break
|
break
|
||||||
lr = optimizer.get_lr()
|
lr = optimizer.get_lr()
|
||||||
images = batch[0]
|
images = batch[0]
|
||||||
|
if config['Architecture']['algorithm'] == "SRN":
|
||||||
|
others = batch[-4:]
|
||||||
|
preds = model(images, others)
|
||||||
|
model_average = True
|
||||||
|
else:
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
loss = loss_class(preds, batch)
|
loss = loss_class(preds, batch)
|
||||||
avg_loss = loss['loss']
|
avg_loss = loss['loss']
|
||||||
|
@ -238,6 +244,13 @@ def train(config,
|
||||||
# eval
|
# eval
|
||||||
if global_step > start_eval_step and \
|
if global_step > start_eval_step and \
|
||||||
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
|
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
|
||||||
|
if model_average:
|
||||||
|
Model_Average = paddle.incubate.optimizer.ModelAverage(
|
||||||
|
0.15,
|
||||||
|
parameters=model.parameters(),
|
||||||
|
min_average_window=10000,
|
||||||
|
max_average_window=15625)
|
||||||
|
Model_Average.apply()
|
||||||
cur_metric = eval(model, valid_dataloader, post_process_class,
|
cur_metric = eval(model, valid_dataloader, post_process_class,
|
||||||
eval_class)
|
eval_class)
|
||||||
cur_metric_str = 'cur metric, {}'.format(', '.join(
|
cur_metric_str = 'cur metric, {}'.format(', '.join(
|
||||||
|
@ -273,6 +286,7 @@ def train(config,
|
||||||
best_model_dict[main_indicator],
|
best_model_dict[main_indicator],
|
||||||
global_step)
|
global_step)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
optimizer.clear_grad()
|
||||||
batch_start = time.time()
|
batch_start = time.time()
|
||||||
if dist.get_rank() == 0:
|
if dist.get_rank() == 0:
|
||||||
save_model(
|
save_model(
|
||||||
|
@ -313,6 +327,10 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
|
||||||
break
|
break
|
||||||
images = batch[0]
|
images = batch[0]
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
if "SRN" in str(model.head):
|
||||||
|
others = batch[-4:]
|
||||||
|
preds = model(images, others)
|
||||||
|
else:
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
|
|
||||||
batch = [item.numpy() for item in batch]
|
batch = [item.numpy() for item in batch]
|
||||||
|
|
5
train.sh
5
train.sh
|
@ -1,5 +1,2 @@
|
||||||
# for paddle.__version__ >= 2.0rc1
|
# recommended paddle.__version__ == 2.0.0
|
||||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
|
||||||
|
|
||||||
# for paddle.__version__ < 2.0rc1
|
|
||||||
# python3 -m paddle.distributed.launch --selected_gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
|
|
||||||
|
|
Loading…
Reference in New Issue