Merge remote-tracking branch 'upstream/dygraph' into dy1
This commit is contained in:
commit
f20f6d2d27
|
@ -1031,7 +1031,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
||||||
|
|
||||||
for box in self.result_dic:
|
for box in self.result_dic:
|
||||||
trans_dic = {"label": box[1][0], "points": box[0], 'difficult': False}
|
trans_dic = {"label": box[1][0], "points": box[0], 'difficult': False}
|
||||||
if trans_dic["label"] is "" and mode == 'Auto':
|
if trans_dic["label"] == "" and mode == 'Auto':
|
||||||
continue
|
continue
|
||||||
shapes.append(trans_dic)
|
shapes.append(trans_dic)
|
||||||
|
|
||||||
|
@ -1763,7 +1763,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
||||||
QMessageBox.information(self, "Information", msg)
|
QMessageBox.information(self, "Information", msg)
|
||||||
return
|
return
|
||||||
result = self.ocr.ocr(img_crop, cls=True, det=False)
|
result = self.ocr.ocr(img_crop, cls=True, det=False)
|
||||||
if result[0][0] is not '':
|
if result[0][0] != '':
|
||||||
result.insert(0, box)
|
result.insert(0, box)
|
||||||
print('result in reRec is ', result)
|
print('result in reRec is ', result)
|
||||||
self.result_dic.append(result)
|
self.result_dic.append(result)
|
||||||
|
@ -1794,7 +1794,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
||||||
QMessageBox.information(self, "Information", msg)
|
QMessageBox.information(self, "Information", msg)
|
||||||
return
|
return
|
||||||
result = self.ocr.ocr(img_crop, cls=True, det=False)
|
result = self.ocr.ocr(img_crop, cls=True, det=False)
|
||||||
if result[0][0] is not '':
|
if result[0][0] != '':
|
||||||
result.insert(0, box)
|
result.insert(0, box)
|
||||||
print('result in reRec is ', result)
|
print('result in reRec is ', result)
|
||||||
if result[1][0] == shape.label:
|
if result[1][0] == shape.label:
|
||||||
|
@ -1999,7 +1999,7 @@ if __name__ == '__main__':
|
||||||
resource_file = './libs/resources.py'
|
resource_file = './libs/resources.py'
|
||||||
if not os.path.exists(resource_file):
|
if not os.path.exists(resource_file):
|
||||||
output = os.system('pyrcc5 -o libs/resources.py resources.qrc')
|
output = os.system('pyrcc5 -o libs/resources.py resources.qrc')
|
||||||
assert output is 0, "operate the cmd have some problems ,please check whether there is a in the lib " \
|
assert output == 0, "operate the cmd have some problems ,please check whether there is a in the lib " \
|
||||||
"directory resources.py "
|
"directory resources.py "
|
||||||
import libs.resources
|
import libs.resources
|
||||||
sys.exit(main())
|
sys.exit(main())
|
||||||
|
|
|
@ -5,10 +5,11 @@ PaddleOCR aims to create multilingual, awesome, leading, and practical OCR tools
|
||||||
|
|
||||||
## Notice
|
## Notice
|
||||||
PaddleOCR supports both dynamic graph and static graph programming paradigm
|
PaddleOCR supports both dynamic graph and static graph programming paradigm
|
||||||
- Dynamic graph: dygraph branch (default), **supported by paddle 2.0rc1+ ([installation](./doc/doc_en/installation_en.md))**
|
- Dynamic graph: dygraph branch (default), **supported by paddle 2.0.0 ([installation](./doc/doc_en/installation_en.md))**
|
||||||
- Static graph: develop branch
|
- Static graph: develop branch
|
||||||
|
|
||||||
**Recent updates**
|
**Recent updates**
|
||||||
|
- 2021.1.21 update more than 25+ multilingual recognition models [models list](./doc/doc_en/models_list_en.md), including:English, Chinese, German, French, Japanese,Spanish,Portuguese Russia Arabic and so on. Models for more languages will continue to be updated [Develop Plan](https://github.com/PaddlePaddle/PaddleOCR/issues/1048).
|
||||||
- 2020.12.15 update Data synthesis tool, i.e., [Style-Text](./StyleText/README.md),easy to synthesize a large number of images which are similar to the target scene image.
|
- 2020.12.15 update Data synthesis tool, i.e., [Style-Text](./StyleText/README.md),easy to synthesize a large number of images which are similar to the target scene image.
|
||||||
- 2020.11.25 Update a new data annotation tool, i.e., [PPOCRLabel](./PPOCRLabel/README.md), which is helpful to improve the labeling efficiency. Moreover, the labeling results can be used in training of the PP-OCR system directly.
|
- 2020.11.25 Update a new data annotation tool, i.e., [PPOCRLabel](./PPOCRLabel/README.md), which is helpful to improve the labeling efficiency. Moreover, the labeling results can be used in training of the PP-OCR system directly.
|
||||||
- 2020.9.22 Update the PP-OCR technical article, https://arxiv.org/abs/2009.09941
|
- 2020.9.22 Update the PP-OCR technical article, https://arxiv.org/abs/2009.09941
|
||||||
|
|
|
@ -4,11 +4,13 @@
|
||||||
PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力使用者训练出更好的模型,并应用落地。
|
PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力使用者训练出更好的模型,并应用落地。
|
||||||
## 注意
|
## 注意
|
||||||
PaddleOCR同时支持动态图与静态图两种编程范式
|
PaddleOCR同时支持动态图与静态图两种编程范式
|
||||||
- 动态图版本:dygraph分支(默认),需将paddle版本升级至2.0rc1+([快速安装](./doc/doc_ch/installation.md))
|
- 动态图版本:dygraph分支(默认),需将paddle版本升级至2.0.0([快速安装](./doc/doc_ch/installation.md))
|
||||||
- 静态图版本:develop分支
|
- 静态图版本:develop分支
|
||||||
|
|
||||||
**近期更新**
|
**近期更新**
|
||||||
- 2021.1.18 [FAQ](./doc/doc_ch/FAQ.md)新增5个高频问题,总数152个,每周一都会更新,欢迎大家持续关注。
|
- 2021.1.26,28,29 PaddleOCR官方研发团队带来技术深入解读三日直播课,1月26日、28日、29日晚上19:30,[直播地址](https://live.bilibili.com/21689802)
|
||||||
|
- 2021.1.25 [FAQ](./doc/doc_ch/FAQ.md)新增5个高频问题,总数157个,每周一都会更新,欢迎大家持续关注。
|
||||||
|
- 2021.1.21 更新多语言识别模型,目前支持语种超过27种,[多语言模型下载](./doc/doc_ch/models_list.md),包括中文简体、中文繁体、英文、法文、德文、韩文、日文、意大利文、西班牙文、葡萄牙文、俄罗斯文、阿拉伯文等,后续计划可以参考[多语言研发计划](https://github.com/PaddlePaddle/PaddleOCR/issues/1048)
|
||||||
- 2020.12.15 更新数据合成工具[Style-Text](./StyleText/README_ch.md),可以批量合成大量与目标场景类似的图像,在多个场景验证,效果明显提升。
|
- 2020.12.15 更新数据合成工具[Style-Text](./StyleText/README_ch.md),可以批量合成大量与目标场景类似的图像,在多个场景验证,效果明显提升。
|
||||||
- 2020.11.25 更新半自动标注工具[PPOCRLabel](./PPOCRLabel/README_ch.md),辅助开发者高效完成标注任务,输出格式与PP-OCR训练任务完美衔接。
|
- 2020.11.25 更新半自动标注工具[PPOCRLabel](./PPOCRLabel/README_ch.md),辅助开发者高效完成标注任务,输出格式与PP-OCR训练任务完美衔接。
|
||||||
- 2020.9.22 更新PP-OCR技术文章,https://arxiv.org/abs/2009.09941
|
- 2020.9.22 更新PP-OCR技术文章,https://arxiv.org/abs/2009.09941
|
||||||
|
|
|
@ -72,7 +72,7 @@ fusion_generator:
|
||||||
python3 tools/synth_image.py -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en
|
python3 tools/synth_image.py -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en
|
||||||
```
|
```
|
||||||
|
|
||||||
* Note 1: The language options is correspond to the corpus. Currently, the tool only supports English, Simplified Chinese and Korean.
|
* Note 1: The language options is correspond to the corpus. Currently, the tool only supports English(en), Simplified Chinese(ch) and Korean(ko).
|
||||||
* Note 2: Synth-Text is mainly used to generate images for OCR recognition models.
|
* Note 2: Synth-Text is mainly used to generate images for OCR recognition models.
|
||||||
So the height of style images should be around 32 pixels. Images in other sizes may behave poorly.
|
So the height of style images should be around 32 pixels. Images in other sizes may behave poorly.
|
||||||
* Note 3: You can modify `use_gpu` in `configs/config.yml` to determine whether to use GPU for prediction.
|
* Note 3: You can modify `use_gpu` in `configs/config.yml` to determine whether to use GPU for prediction.
|
||||||
|
@ -120,7 +120,7 @@ In actual application scenarios, it is often necessary to synthesize pictures in
|
||||||
* `with_label`:Whether the `label_file` is label file list.
|
* `with_label`:Whether the `label_file` is label file list.
|
||||||
* `CorpusGenerator`:
|
* `CorpusGenerator`:
|
||||||
* `method`:Method of CorpusGenerator,supports `FileCorpus` and `EnNumCorpus`. If `EnNumCorpus` is used,No other configuration is needed,otherwise you need to set `corpus_file` and `language`.
|
* `method`:Method of CorpusGenerator,supports `FileCorpus` and `EnNumCorpus`. If `EnNumCorpus` is used,No other configuration is needed,otherwise you need to set `corpus_file` and `language`.
|
||||||
* `language`:Language of the corpus.
|
* `language`:Language of the corpus. Currently, the tool only supports English(en), Simplified Chinese(ch) and Korean(ko).
|
||||||
* `corpus_file`: Filepath of the corpus. Corpus file should be a text file which will be split by line-endings('\n'). Corpus generator samples one line each time.
|
* `corpus_file`: Filepath of the corpus. Corpus file should be a text file which will be split by line-endings('\n'). Corpus generator samples one line each time.
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -63,10 +63,10 @@ fusion_generator:
|
||||||
```python
|
```python
|
||||||
python3 tools/synth_image.py -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en
|
python3 tools/synth_image.py -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en
|
||||||
```
|
```
|
||||||
* 注1:语言选项和语料相对应,目前该工具只支持英文、简体中文和韩语。
|
* 注1:语言选项和语料相对应,目前支持英文(en)、简体中文(ch)和韩语(ko)。
|
||||||
* 注2:Style-Text生成的数据主要应用于OCR识别场景。基于当前PaddleOCR识别模型的设计,我们主要支持高度在32左右的风格图像。
|
* 注2:Style-Text生成的数据主要应用于OCR识别场景。基于当前PaddleOCR识别模型的设计,我们主要支持高度在32左右的风格图像。
|
||||||
如果输入图像尺寸相差过多,效果可能不佳。
|
如果输入图像尺寸相差过多,效果可能不佳。
|
||||||
* 注3:可以通过修改配置文件中的`use_gpu`(true或者false)参数来决定是否使用GPU进行预测。
|
* 注3:可以通过修改配置文件`configs/config.yml`中的`use_gpu`(true或者false)参数来决定是否使用GPU进行预测。
|
||||||
|
|
||||||
|
|
||||||
例如,输入如下图片和语料"PaddleOCR":
|
例如,输入如下图片和语料"PaddleOCR":
|
||||||
|
@ -105,7 +105,7 @@ python3 tools/synth_image.py -c configs/config.yml --style_image examples/style_
|
||||||
* `with_label`:标志`label_file`是否为label文件。
|
* `with_label`:标志`label_file`是否为label文件。
|
||||||
* `CorpusGenerator`:
|
* `CorpusGenerator`:
|
||||||
* `method`:语料生成方法,目前有`FileCorpus`和`EnNumCorpus`可选。如果使用`EnNumCorpus`,则不需要填写其他配置,否则需要修改`corpus_file`和`language`;
|
* `method`:语料生成方法,目前有`FileCorpus`和`EnNumCorpus`可选。如果使用`EnNumCorpus`,则不需要填写其他配置,否则需要修改`corpus_file`和`language`;
|
||||||
* `language`:语料的语种;
|
* `language`:语料的语种,目前支持英文(en)、简体中文(ch)和韩语(ko);
|
||||||
* `corpus_file`: 语料文件路径。语料文件应使用文本文件。语料生成器首先会将语料按行切分,之后每次随机选取一行。
|
* `corpus_file`: 语料文件路径。语料文件应使用文本文件。语料生成器首先会将语料按行切分,之后每次随机选取一行。
|
||||||
|
|
||||||
语料文件格式示例:
|
语料文件格式示例:
|
||||||
|
|
|
@ -16,7 +16,7 @@ Global:
|
||||||
infer_img:
|
infer_img:
|
||||||
# for data or label process
|
# for data or label process
|
||||||
character_dict_path: ppocr/utils/dict/en_dict.txt
|
character_dict_path: ppocr/utils/dict/en_dict.txt
|
||||||
character_type: ch
|
character_type: EN
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
infer_mode: False
|
infer_mode: False
|
||||||
use_space_char: False
|
use_space_char: False
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
Global:
|
Global:
|
||||||
use_gpu: true
|
use_gpu: True
|
||||||
epoch_num: 72
|
epoch_num: 72
|
||||||
log_smooth_window: 20
|
log_smooth_window: 20
|
||||||
print_batch_step: 10
|
print_batch_step: 10
|
||||||
|
@ -59,7 +59,7 @@ Metric:
|
||||||
|
|
||||||
Train:
|
Train:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/training/
|
data_dir: ./train_data/data_lmdb_release/training/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
@ -78,7 +78,7 @@ Train:
|
||||||
|
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/validation/
|
data_dir: ./train_data/data_lmdb_release/validation/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
|
|
@ -58,7 +58,7 @@ Metric:
|
||||||
|
|
||||||
Train:
|
Train:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/training/
|
data_dir: ./train_data/data_lmdb_release/training/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
@ -77,7 +77,7 @@ Train:
|
||||||
|
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/validation/
|
data_dir: ./train_data/data_lmdb_release/validation/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
|
|
@ -63,7 +63,7 @@ Metric:
|
||||||
|
|
||||||
Train:
|
Train:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/training/
|
data_dir: ./train_data/data_lmdb_release/training/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
@ -82,7 +82,7 @@ Train:
|
||||||
|
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/validation/
|
data_dir: ./train_data/data_lmdb_release/validation/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
|
|
@ -58,7 +58,7 @@ Metric:
|
||||||
|
|
||||||
Train:
|
Train:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/training/
|
data_dir: ./train_data/data_lmdb_release/training/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
@ -77,7 +77,7 @@ Train:
|
||||||
|
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/validation/
|
data_dir: ./train_data/data_lmdb_release/validation/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
|
|
@ -56,7 +56,7 @@ Metric:
|
||||||
|
|
||||||
Train:
|
Train:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/training/
|
data_dir: ./train_data/data_lmdb_release/training/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
@ -75,7 +75,7 @@ Train:
|
||||||
|
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/validation/
|
data_dir: ./train_data/data_lmdb_release/validation/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
|
|
@ -62,7 +62,7 @@ Metric:
|
||||||
|
|
||||||
Train:
|
Train:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/training/
|
data_dir: ./train_data/data_lmdb_release/training/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
@ -81,7 +81,7 @@ Train:
|
||||||
|
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
name: LMDBDateSet
|
name: LMDBDataSet
|
||||||
data_dir: ./train_data/data_lmdb_release/validation/
|
data_dir: ./train_data/data_lmdb_release/validation/
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
|
|
@ -0,0 +1,107 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: True
|
||||||
|
epoch_num: 72
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 5
|
||||||
|
save_model_dir: ./output/rec/srn_new
|
||||||
|
save_epoch_step: 3
|
||||||
|
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||||
|
eval_batch_step: [0, 5000]
|
||||||
|
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||||
|
cal_metric_during_train: True
|
||||||
|
pretrained_model:
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||||
|
# for data or label process
|
||||||
|
character_dict_path:
|
||||||
|
character_type: en
|
||||||
|
max_text_length: 25
|
||||||
|
num_heads: 8
|
||||||
|
infer_mode: False
|
||||||
|
use_space_char: False
|
||||||
|
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
clip_norm: 10.0
|
||||||
|
lr:
|
||||||
|
learning_rate: 0.0001
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: rec
|
||||||
|
algorithm: SRN
|
||||||
|
in_channels: 1
|
||||||
|
Transform:
|
||||||
|
Backbone:
|
||||||
|
name: ResNetFPN
|
||||||
|
Head:
|
||||||
|
name: SRNHead
|
||||||
|
max_text_length: 25
|
||||||
|
num_heads: 8
|
||||||
|
num_encoder_TUs: 2
|
||||||
|
num_decoder_TUs: 4
|
||||||
|
hidden_dims: 512
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: SRNLoss
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: SRNLabelDecode
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: RecMetric
|
||||||
|
main_indicator: acc
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDataSet
|
||||||
|
data_dir: ./train_data/srn_train_data_duiqi
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- SRNLabelEncode: # Class handling label
|
||||||
|
- SRNRecResizeImg:
|
||||||
|
image_shape: [1, 64, 256]
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image',
|
||||||
|
'label',
|
||||||
|
'length',
|
||||||
|
'encoder_word_pos',
|
||||||
|
'gsrm_word_pos',
|
||||||
|
'gsrm_slf_attn_bias1',
|
||||||
|
'gsrm_slf_attn_bias2'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
batch_size_per_card: 64
|
||||||
|
drop_last: False
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDataSet
|
||||||
|
data_dir: ./train_data/data_lmdb_release/evaluation
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- SRNLabelEncode: # Class handling label
|
||||||
|
- SRNRecResizeImg:
|
||||||
|
image_shape: [1, 64, 256]
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image',
|
||||||
|
'label',
|
||||||
|
'length',
|
||||||
|
'encoder_word_pos',
|
||||||
|
'gsrm_word_pos',
|
||||||
|
'gsrm_slf_attn_bias1',
|
||||||
|
'gsrm_slf_attn_bias2']
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 32
|
||||||
|
num_workers: 4
|
|
@ -42,7 +42,7 @@ python deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o Global
|
||||||
# 比如下载提供的训练模型
|
# 比如下载提供的训练模型
|
||||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar
|
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar
|
||||||
tar -xf ch_ppocr_mobile_v2.0_det_train.tar
|
tar -xf ch_ppocr_mobile_v2.0_det_train.tar
|
||||||
python deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./ch_ppocr_mobile_v2.0_det_train/best_accuracy Global.save_model_dir=./output/quant_model
|
python deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./ch_ppocr_mobile_v2.0_det_train/best_accuracy Global.save_inference_dir=./output/quant_inference_model
|
||||||
|
|
||||||
```
|
```
|
||||||
如果要训练识别模型的量化,修改配置文件和加载的模型参数即可。
|
如果要训练识别模型的量化,修改配置文件和加载的模型参数即可。
|
||||||
|
|
|
@ -58,7 +58,7 @@ python deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o Global
|
||||||
After getting the model after pruning and finetuning we, can export it as inference_model for predictive deployment:
|
After getting the model after pruning and finetuning we, can export it as inference_model for predictive deployment:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python deploy/slim/quantization/export_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=output/quant_model/best_accuracy Global.save_model_dir=./output/quant_inference_model
|
python deploy/slim/quantization/export_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=output/quant_model/best_accuracy Global.save_inference_dir=./output/quant_inference_model
|
||||||
```
|
```
|
||||||
|
|
||||||
### 5. Deploy
|
### 5. Deploy
|
||||||
|
|
|
@ -9,42 +9,43 @@
|
||||||
|
|
||||||
## PaddleOCR常见问题汇总(持续更新)
|
## PaddleOCR常见问题汇总(持续更新)
|
||||||
|
|
||||||
* [近期更新(2021.1.18)](#近期更新)
|
* [近期更新(2021.1.25)](#近期更新)
|
||||||
* [【精选】OCR精选10个问题](#OCR精选10个问题)
|
* [【精选】OCR精选10个问题](#OCR精选10个问题)
|
||||||
* [【理论篇】OCR通用32个问题](#OCR通用问题)
|
* [【理论篇】OCR通用32个问题](#OCR通用问题)
|
||||||
* [基础知识7题](#基础知识)
|
* [基础知识7题](#基础知识)
|
||||||
* [数据集7题](#数据集2)
|
* [数据集7题](#数据集2)
|
||||||
* [模型训练调优18题](#模型训练调优2)
|
* [模型训练调优18题](#模型训练调优2)
|
||||||
* [【实战篇】PaddleOCR实战110个问题](#PaddleOCR实战问题)
|
* [【实战篇】PaddleOCR实战115个问题](#PaddleOCR实战问题)
|
||||||
* [使用咨询36题](#使用咨询)
|
* [使用咨询38题](#使用咨询)
|
||||||
* [数据集17题](#数据集3)
|
* [数据集17题](#数据集3)
|
||||||
* [模型训练调优28题](#模型训练调优3)
|
* [模型训练调优28题](#模型训练调优3)
|
||||||
* [预测部署29题](#预测部署3)
|
* [预测部署32题](#预测部署3)
|
||||||
|
|
||||||
|
|
||||||
<a name="近期更新"></a>
|
<a name="近期更新"></a>
|
||||||
## 近期更新(2021.1.18)
|
## 近期更新(2021.1.25)
|
||||||
|
|
||||||
|
#### Q3.1.37: 小语种模型只有识别模型,没有检测模型吗?
|
||||||
|
|
||||||
#### Q2.3.18: 在PP-OCR系统中,文本检测的骨干网络为什么没有使用SE模块?
|
**A**:小语种(包括纯英文数字)的检测模型和中文的检测模型是共用的,在训练中文检测模型时加入了多语言数据。https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_en/models_list_en.md#1-text-detection-model。
|
||||||
|
|
||||||
**A**:SE模块是MobileNetV3网络一个重要模块,目的是估计特征图每个特征通道重要性,给特征图每个特征分配权重,提高网络的表达能力。但是,对于文本检测,输入网络的分辨率比较大,一般是640\*640,利用SE模块估计特征图每个特征通道重要性比较困难,网络提升能力有限,但是该模块又比较耗时,因此在PP-OCR系统中,文本检测的骨干网络没有使用SE模块。实验也表明,当去掉SE模块,超轻量模型大小可以减小40%,文本检测效果基本不受影响。详细可以参考PP-OCR技术文章,https://arxiv.org/abs/2009.09941.
|
#### Q3.1.38: module 'paddle.distributed' has no attribute ‘get_rank’。
|
||||||
|
|
||||||
#### Q3.3.27: PaddleOCR关于文本识别模型的训练,支持的数据增强方式有哪些?
|
**A**:Paddle版本问题,请安装2.0版本Paddle:pip install paddlepaddle==2.0.0。
|
||||||
|
|
||||||
**A**:文本识别支持的数据增强方式有随机小幅度裁剪、图像平衡、添加白噪声、颜色漂移、图像反色和Text Image Augmentation(TIA)变换等。可以参考[代码](../../ppocr/data/imaug/rec_img_aug.py)中的warp函数。
|
#### Q3.4.30: PaddleOCR是否支持在华为鲲鹏920CPU上部署?
|
||||||
|
|
||||||
#### Q3.3.28: 关于dygraph分支中,文本识别模型训练,要使用数据增强应该如何设置?
|
**A**:目前Paddle的预测库是支持华为鲲鹏920CPU的,但是OCR还没在这些芯片上测试过,可以自己调试,有问题反馈给我们。
|
||||||
|
|
||||||
**A**:可以参考[配置文件](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)在Train['dataset']['transforms']添加RecAug字段,使数据增强生效。可以通过添加对aug_prob设置,表示每种数据增强采用的概率。aug_prob默认是0.4.由于tia数据增强特殊性,默认不采用,可以通过添加use_tia设置,使tia数据增强生效。详细设置可以参考[ISSUE 1744](https://github.com/PaddlePaddle/PaddleOCR/issues/1744)。
|
#### Q3.4.31: 采用Paddle-Lite进行端侧部署,出现问题,环境没问题。
|
||||||
|
|
||||||
#### Q3.4.28: PP-OCR系统中,文本检测的结果有置信度吗?
|
**A**:如果你的预测库是自己编译的,那么你的nb文件也要自己编译,用同一个lite版本。不能直接用下载的nb文件,因为版本不同。
|
||||||
|
|
||||||
**A**:文本检测的结果有置信度,由于推理过程中没有使用,所以没有显示的返回到最终结果中。如果需要文本检测结果的置信度,可以在[文本检测DB的后处理代码](../../ppocr/postprocess/db_postprocess.py)的155行,添加scores信息。这样,在[检测预测代码](../../tools/infer/predict_det.py)的197行,就可以拿到文本检测的scores信息。
|
#### Q3.4.32: PaddleOCR的模型支持onnx转换吗?
|
||||||
|
|
||||||
#### Q3.4.29: DB文本检测,特征提取网络金字塔构建的部分代码在哪儿?
|
**A**:我们目前已经通过Paddle2ONNX来支持各模型套件的转换,PaddleOCR基于PaddlePaddle 2.0的版本(dygraph分支)已经支持导出为ONNX,欢迎关注Paddle2ONNX,了解更多项目的进展:
|
||||||
|
Paddle2ONNX项目:https://github.com/PaddlePaddle/Paddle2ONNX
|
||||||
**A**:特征提取网络金字塔构建的部分:[代码位置](../../ppocr/modeling/necks/db_fpn.py)。ppocr/modeling文件夹里面是组网相关的代码,其中architectures是文本检测或者文本识别整体流程代码;backbones是骨干网络相关代码;necks是类似与FPN的颈函数代码;heads是提取文本检测或者文本识别预测结果相关的头函数;transforms是类似于TPS特征预处理模块。更多的信息可以参考[代码组织结构](./tree.md)。
|
Paddle2ONNX支持转换的[模型列表](https://github.com/PaddlePaddle/Paddle2ONNX/blob/develop/docs/zh/model_zoo.md#%E5%9B%BE%E5%83%8Focr)
|
||||||
|
|
||||||
<a name="OCR精选10个问题"></a>
|
<a name="OCR精选10个问题"></a>
|
||||||
## 【精选】OCR精选10个问题
|
## 【精选】OCR精选10个问题
|
||||||
|
@ -396,13 +397,13 @@
|
||||||
**A**:动态图版本正在紧锣密鼓开发中,将于2020年12月16日发布,敬请关注。
|
**A**:动态图版本正在紧锣密鼓开发中,将于2020年12月16日发布,敬请关注。
|
||||||
|
|
||||||
#### Q3.1.22:ModuleNotFoundError: No module named 'paddle.nn',
|
#### Q3.1.22:ModuleNotFoundError: No module named 'paddle.nn',
|
||||||
**A**:paddle.nn是Paddle2.0版本特有的功能,请安装大于等于Paddle 2.0.0rc1的版本,安装方式为
|
**A**:paddle.nn是Paddle2.0版本特有的功能,请安装大于等于Paddle 2.0.0的版本,安装方式为
|
||||||
```
|
```
|
||||||
python3 -m pip install paddlepaddle-gpu==2.0.0rc1 -i https://mirror.baidu.com/pypi/simple
|
python3 -m pip install paddlepaddle-gpu==2.0.0 -i https://mirror.baidu.com/pypi/simple
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Q3.1.23: ImportError: /usr/lib/x86_64_linux-gnu/libstdc++.so.6:version `CXXABI_1.3.11` not found (required by /usr/lib/python3.6/site-package/paddle/fluid/core+avx.so)
|
#### Q3.1.23: ImportError: /usr/lib/x86_64_linux-gnu/libstdc++.so.6:version `CXXABI_1.3.11` not found (required by /usr/lib/python3.6/site-package/paddle/fluid/core+avx.so)
|
||||||
**A**:这个问题是glibc版本不足导致的,Paddle2.0rc1版本对gcc版本和glib版本有更高的要求,推荐gcc版本为8.2,glibc版本2.12以上。
|
**A**:这个问题是glibc版本不足导致的,Paddle2.0.0版本对gcc版本和glib版本有更高的要求,推荐gcc版本为8.2,glibc版本2.12以上。
|
||||||
如果您的环境不满足这个要求,或者使用的docker镜像为:
|
如果您的环境不满足这个要求,或者使用的docker镜像为:
|
||||||
`hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev`
|
`hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev`
|
||||||
`hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev`,安装Paddle2.0rc版本可能会出现上述错误,2.0版本推荐使用新的docker镜像 `paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82`。
|
`hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev`,安装Paddle2.0rc版本可能会出现上述错误,2.0版本推荐使用新的docker镜像 `paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82`。
|
||||||
|
@ -414,7 +415,7 @@ python3 -m pip install paddlepaddle-gpu==2.0.0rc1 -i https://mirror.baidu.com/py
|
||||||
|
|
||||||
- develop:基于Paddle静态图开发的分支,推荐使用paddle1.8 或者2.0版本,该分支具备完善的模型训练、预测、推理部署、量化裁剪等功能,领先于release/1.1分支。
|
- develop:基于Paddle静态图开发的分支,推荐使用paddle1.8 或者2.0版本,该分支具备完善的模型训练、预测、推理部署、量化裁剪等功能,领先于release/1.1分支。
|
||||||
- release/1.1:PaddleOCR 发布的第一个稳定版本,基于静态图开发,具备完善的训练、预测、推理部署、量化裁剪等功能。
|
- release/1.1:PaddleOCR 发布的第一个稳定版本,基于静态图开发,具备完善的训练、预测、推理部署、量化裁剪等功能。
|
||||||
- dygraph:基于Paddle动态图开发的分支,目前仍在开发中,未来将作为主要开发分支,运行要求使用Paddle2.0rc1版本,目前仍在开发中。
|
- dygraph:基于Paddle动态图开发的分支,目前仍在开发中,未来将作为主要开发分支,运行要求使用Paddle2.0.0版本。
|
||||||
- release/2.0-rc1-0:PaddleOCR发布的第二个稳定版本,基于动态图和paddle2.0版本开发,动态图开发的工程更易于调试,目前支,支持模型训练、预测,暂不支持移动端部署。
|
- release/2.0-rc1-0:PaddleOCR发布的第二个稳定版本,基于动态图和paddle2.0版本开发,动态图开发的工程更易于调试,目前支,支持模型训练、预测,暂不支持移动端部署。
|
||||||
|
|
||||||
如果您已经上手过PaddleOCR,并且希望在各种环境上部署PaddleOCR,目前建议使用静态图分支,develop或者release/1.1分支。如果您是初学者,想快速训练,调试PaddleOCR中的算法,建议尝鲜PaddleOCR dygraph分支。
|
如果您已经上手过PaddleOCR,并且希望在各种环境上部署PaddleOCR,目前建议使用静态图分支,develop或者release/1.1分支。如果您是初学者,想快速训练,调试PaddleOCR中的算法,建议尝鲜PaddleOCR dygraph分支。
|
||||||
|
@ -431,7 +432,7 @@ python3 -m pip install paddlepaddle-gpu==2.0.0rc1 -i https://mirror.baidu.com/py
|
||||||
|
|
||||||
#### Q3.1.27: 如何可视化acc,loss曲线图,模型网络结构图等?
|
#### Q3.1.27: 如何可视化acc,loss曲线图,模型网络结构图等?
|
||||||
|
|
||||||
**A**:在配置文件里有`use_visualdl`的参数,设置为True即可,更多的使用命令可以参考:[VisualDL使用指南](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/guides/03_VisualDL/visualdl.html)。
|
**A**:在配置文件里有`use_visualdl`的参数,设置为True即可,更多的使用命令可以参考:[VisualDL使用指南](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/03_VisualDL/visualdl.html)。
|
||||||
|
|
||||||
#### Q3.1.28: 在使用StyleText数据合成工具的时候,报错`ModuleNotFoundError: No module named 'utils.config'`,这是为什么呢?
|
#### Q3.1.28: 在使用StyleText数据合成工具的时候,报错`ModuleNotFoundError: No module named 'utils.config'`,这是为什么呢?
|
||||||
|
|
||||||
|
@ -450,7 +451,7 @@ https://github.com/PaddlePaddle/PaddleOCR/blob/de3e2e7cd3b8b65ee02d7a41e570fa5b5
|
||||||
|
|
||||||
#### Q3.1.31: 怎么输出网络结构以及每层的参数信息?
|
#### Q3.1.31: 怎么输出网络结构以及每层的参数信息?
|
||||||
|
|
||||||
**A**:可以使用 `paddle.summary`, 具体参考:https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/api/paddle/hapi/model_summary/summary_cn.html#summary。
|
**A**:可以使用 `paddle.summary`, 具体参考:https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/hapi/model_summary/summary_cn.html。
|
||||||
|
|
||||||
#### Q3.1.32 能否修改StyleText配置文件中的分辨率?
|
#### Q3.1.32 能否修改StyleText配置文件中的分辨率?
|
||||||
|
|
||||||
|
@ -474,9 +475,18 @@ StyleText的用途主要是:提取style_image中的字体、背景等style信
|
||||||
例如识别身份证照片,可以先匹配"姓名","性别"等关键字,根据这些关键字的坐标去推测其他信息的位置,再与识别的结果匹配。
|
例如识别身份证照片,可以先匹配"姓名","性别"等关键字,根据这些关键字的坐标去推测其他信息的位置,再与识别的结果匹配。
|
||||||
|
|
||||||
#### Q3.1.36 如何识别竹简上的古文?
|
#### Q3.1.36 如何识别竹简上的古文?
|
||||||
|
|
||||||
**A**:对于字符都是普通的汉字字符的情况,只要标注足够的数据,finetune模型就可以了。如果数据量不足,您可以尝试StyleText工具。
|
**A**:对于字符都是普通的汉字字符的情况,只要标注足够的数据,finetune模型就可以了。如果数据量不足,您可以尝试StyleText工具。
|
||||||
而如果使用的字符是特殊的古文字、甲骨文、象形文字等,那么首先需要构建一个古文字的字典,之后再进行训练。
|
而如果使用的字符是特殊的古文字、甲骨文、象形文字等,那么首先需要构建一个古文字的字典,之后再进行训练。
|
||||||
|
|
||||||
|
#### Q3.1.37: 小语种模型只有识别模型,没有检测模型吗?
|
||||||
|
|
||||||
|
**A**:小语种(包括纯英文数字)的检测模型和中文的检测模型是共用的,在训练中文检测模型时加入了多语言数据。https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_en/models_list_en.md#1-text-detection-model。
|
||||||
|
|
||||||
|
#### Q3.1.38: module 'paddle.distributed' has no attribute ‘get_rank’。
|
||||||
|
|
||||||
|
**A**:Paddle版本问题,请安装2.0版本Paddle:pip install paddlepaddle==2.0.0。
|
||||||
|
|
||||||
<a name="数据集3"></a>
|
<a name="数据集3"></a>
|
||||||
### 数据集
|
### 数据集
|
||||||
|
|
||||||
|
@ -854,3 +864,17 @@ img = cv.imdecode(img_array, -1)
|
||||||
#### Q3.4.29: DB文本检测,特征提取网络金字塔构建的部分代码在哪儿?
|
#### Q3.4.29: DB文本检测,特征提取网络金字塔构建的部分代码在哪儿?
|
||||||
|
|
||||||
**A**:特征提取网络金字塔构建的部分:[代码位置](../../ppocr/modeling/necks/db_fpn.py)。ppocr/modeling文件夹里面是组网相关的代码,其中architectures是文本检测或者文本识别整体流程代码;backbones是骨干网络相关代码;necks是类似与FPN的颈函数代码;heads是提取文本检测或者文本识别预测结果相关的头函数;transforms是类似于TPS特征预处理模块。更多的信息可以参考[代码组织结构](./tree.md)。
|
**A**:特征提取网络金字塔构建的部分:[代码位置](../../ppocr/modeling/necks/db_fpn.py)。ppocr/modeling文件夹里面是组网相关的代码,其中architectures是文本检测或者文本识别整体流程代码;backbones是骨干网络相关代码;necks是类似与FPN的颈函数代码;heads是提取文本检测或者文本识别预测结果相关的头函数;transforms是类似于TPS特征预处理模块。更多的信息可以参考[代码组织结构](./tree.md)。
|
||||||
|
|
||||||
|
#### Q3.4.30: PaddleOCR是否支持在华为鲲鹏920CPU上部署?
|
||||||
|
|
||||||
|
**A**:目前Paddle的预测库是支持华为鲲鹏920CPU的,但是OCR还没在这些芯片上测试过,可以自己调试,有问题反馈给我们。
|
||||||
|
|
||||||
|
#### Q3.4.31: 采用Paddle-Lite进行端侧部署,出现问题,环境没问题。
|
||||||
|
|
||||||
|
**A**:如果你的预测库是自己编译的,那么你的nb文件也要自己编译,用同一个lite版本。不能直接用下载的nb文件,因为版本不同。
|
||||||
|
|
||||||
|
#### Q3.4.32: PaddleOCR的模型支持onnx转换吗?
|
||||||
|
|
||||||
|
**A**:我们目前已经通过Paddle2ONNX来支持各模型套件的转换,PaddleOCR基于PaddlePaddle 2.0的版本(dygraph分支)已经支持导出为ONNX,欢迎关注Paddle2ONNX,了解更多项目的进展:
|
||||||
|
Paddle2ONNX项目:https://github.com/PaddlePaddle/Paddle2ONNX
|
||||||
|
Paddle2ONNX支持转换的[模型列表](https://github.com/PaddlePaddle/Paddle2ONNX/blob/develop/docs/zh/model_zoo.md#%E5%9B%BE%E5%83%8Focr)
|
||||||
|
|
|
@ -63,7 +63,7 @@ PaddleOCR提供了训练脚本、评估脚本和预测脚本。
|
||||||
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
|
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
|
||||||
|
|
||||||
```
|
```
|
||||||
# GPU训练 支持单卡,多卡训练,通过 '--gpus' 指定卡号,如果使用的paddle版本小于2.0rc1,请使用'--select_gpus'参数选择要使用的GPU
|
# GPU训练 支持单卡,多卡训练,通过 '--gpus' 指定卡号。
|
||||||
# 启动训练,下面的命令已经写入train.sh文件中,只需修改文件里的配置文件路径即可
|
# 启动训练,下面的命令已经写入train.sh文件中,只需修改文件里的配置文件路径即可
|
||||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
|
||||||
```
|
```
|
||||||
|
|
|
@ -76,7 +76,7 @@ tar -xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_model
|
||||||
# 单机单卡训练 mv3_db 模型
|
# 单机单卡训练 mv3_db 模型
|
||||||
python3 tools/train.py -c configs/det/det_mv3_db.yml \
|
python3 tools/train.py -c configs/det/det_mv3_db.yml \
|
||||||
-o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/
|
-o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/
|
||||||
# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID;如果使用的paddle版本小于2.0rc1,请使用'--select_gpus'参数选择要使用的GPU
|
# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID
|
||||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \
|
||||||
-o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/
|
-o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/
|
||||||
```
|
```
|
||||||
|
|
|
@ -306,10 +306,10 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
|
||||||
<a name="多语言模型的推理"></a>
|
<a name="多语言模型的推理"></a>
|
||||||
### 4. 多语言模型的推理
|
### 4. 多语言模型的推理
|
||||||
如果您需要预测的是其他语言模型,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径, 同时为了得到正确的可视化结果,
|
如果您需要预测的是其他语言模型,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径, 同时为了得到正确的可视化结果,
|
||||||
需要通过 `--vis_font_path` 指定可视化的字体路径,`doc/` 路径下有默认提供的小语种字体,例如韩文识别:
|
需要通过 `--vis_font_path` 指定可视化的字体路径,`doc/fonts/` 路径下有默认提供的小语种字体,例如韩文识别:
|
||||||
|
|
||||||
```
|
```
|
||||||
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_type="korean" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/korean.ttf"
|
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_type="korean" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
|
||||||
```
|
```
|
||||||
![](../imgs_words/korean/1.jpg)
|
![](../imgs_words/korean/1.jpg)
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
经测试PaddleOCR可在glibc 2.23上运行,您也可以测试其他glibc版本或安装glic 2.23
|
经测试PaddleOCR可在glibc 2.23上运行,您也可以测试其他glibc版本或安装glic 2.23
|
||||||
PaddleOCR 工作环境
|
PaddleOCR 工作环境
|
||||||
- PaddlePaddle 1.8+ ,推荐使用 PaddlePaddle 2.0rc1
|
- PaddlePaddle 2.0.0
|
||||||
- python3.7
|
- python3.7
|
||||||
- glibc 2.23
|
- glibc 2.23
|
||||||
- cuDNN 7.6+ (GPU)
|
- cuDNN 7.6+ (GPU)
|
||||||
|
@ -35,11 +35,11 @@ sudo docker container exec -it ppocr /bin/bash
|
||||||
pip3 install --upgrade pip
|
pip3 install --upgrade pip
|
||||||
|
|
||||||
如果您的机器安装的是CUDA9或CUDA10,请运行以下命令安装
|
如果您的机器安装的是CUDA9或CUDA10,请运行以下命令安装
|
||||||
python3 -m pip install paddlepaddle-gpu==2.0.0rc1 -i https://mirror.baidu.com/pypi/simple
|
python3 -m pip install paddlepaddle-gpu==2.0.0 -i https://mirror.baidu.com/pypi/simple
|
||||||
|
|
||||||
如果您的机器是CPU,请运行以下命令安装
|
如果您的机器是CPU,请运行以下命令安装
|
||||||
|
|
||||||
python3 -m pip install paddlepaddle==2.0.0rc1 -i https://mirror.baidu.com/pypi/simple
|
python3 -m pip install paddlepaddle==2.0.0 -i https://mirror.baidu.com/pypi/simple
|
||||||
|
|
||||||
更多的版本需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
|
更多的版本需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
|
||||||
```
|
```
|
||||||
|
|
|
@ -195,8 +195,6 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
|
||||||
| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: |
|
| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: |
|
||||||
| [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml) | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc |
|
| [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml) | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc |
|
||||||
| [rec_chinese_common_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml) | CRNN | ResNet34_vd | None | BiLSTM | ctc |
|
| [rec_chinese_common_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml) | CRNN | ResNet34_vd | None | BiLSTM | ctc |
|
||||||
| rec_chinese_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc |
|
|
||||||
| rec_chinese_common_train.yml | CRNN | ResNet34_vd | None | BiLSTM | ctc |
|
|
||||||
| rec_icdar15_train.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
|
| rec_icdar15_train.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
|
||||||
| rec_mv3_none_bilstm_ctc.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
|
| rec_mv3_none_bilstm_ctc.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
|
||||||
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
|
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
|
||||||
|
@ -272,16 +270,109 @@ Eval:
|
||||||
<a name="小语种"></a>
|
<a name="小语种"></a>
|
||||||
- 小语种
|
- 小语种
|
||||||
|
|
||||||
PaddleOCR也提供了多语言的, `configs/rec/multi_languages` 路径下的提供了多语言的配置文件,目前PaddleOCR支持的多语言算法有:
|
PaddleOCR目前已支持26种(除中文外)语种识别,`configs/rec/multi_languages` 路径下提供了一个多语言的配置文件模版: [rec_multi_language_lite_train.yml](../../configs/rec/multi_language/rec_multi_language_lite_train.yml)。
|
||||||
|
|
||||||
| 配置文件 | 算法名称 | backbone | trans | seq | pred | language |
|
您有两种方式创建所需的配置文件:
|
||||||
| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: |
|
|
||||||
| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 英语 |
|
|
||||||
| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 法语 |
|
|
||||||
| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 德语 |
|
|
||||||
| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 日语 |
|
|
||||||
| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 韩语 |
|
|
||||||
|
|
||||||
|
1. 通过脚本自动生成
|
||||||
|
|
||||||
|
[generate_multi_language_configs.py](../../configs/rec/multi_language/generate_multi_language_configs.py) 可以帮助您生成多语言模型的配置文件
|
||||||
|
|
||||||
|
- 以意大利语为例,如果您的数据是按如下格式准备的:
|
||||||
|
```
|
||||||
|
|-train_data
|
||||||
|
|- it_train.txt # 训练集标签
|
||||||
|
|- it_val.txt # 验证集标签
|
||||||
|
|- data
|
||||||
|
|- word_001.jpg
|
||||||
|
|- word_002.jpg
|
||||||
|
|- word_003.jpg
|
||||||
|
| ...
|
||||||
|
```
|
||||||
|
|
||||||
|
可以使用默认参数,生成配置文件:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 该代码需要在指定目录运行
|
||||||
|
cd PaddleOCR/configs/rec/multi_language/
|
||||||
|
# 通过-l或者--language参数设置需要生成的语种的配置文件,该命令会将默认参数写入配置文件
|
||||||
|
python3 generate_multi_language_configs.py -l it
|
||||||
|
```
|
||||||
|
|
||||||
|
- 如果您的数据放置在其他位置,或希望使用自己的字典,可以通过指定相关参数来生成配置文件:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# -l或者--language字段是必须的
|
||||||
|
# --train修改训练集,--val修改验证集,--data_dir修改数据集目录,--dict修改字典路径, -o修改对应默认参数
|
||||||
|
cd PaddleOCR/configs/rec/multi_language/
|
||||||
|
python3 generate_multi_language_configs.py -l it \ # 语种
|
||||||
|
--train {path/of/train_label.txt} \ # 训练标签文件的路径
|
||||||
|
--val {path/of/val_label.txt} \ # 验证集标签文件的路径
|
||||||
|
--data_dir {train_data/path} \ # 训练数据的根目录
|
||||||
|
--dict {path/of/dict} \ # 字典文件路径
|
||||||
|
-o Global.use_gpu=False # 是否使用gpu
|
||||||
|
...
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 手动修改配置文件
|
||||||
|
|
||||||
|
您也可以手动修改模版中的以下几个字段:
|
||||||
|
|
||||||
|
```
|
||||||
|
Global:
|
||||||
|
use_gpu: True
|
||||||
|
epoch_num: 500
|
||||||
|
...
|
||||||
|
character_type: it # 需要识别的语种
|
||||||
|
character_dict_path: {path/of/dict} # 字典文件所在路径
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: train_data/ # 数据存放根目录
|
||||||
|
label_file_list: ["./train_data/train_list.txt"] # 训练集label路径
|
||||||
|
...
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: train_data/ # 数据存放根目录
|
||||||
|
label_file_list: ["./train_data/val_list.txt"] # 验证集label路径
|
||||||
|
...
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
目前PaddleOCR支持的多语言算法有:
|
||||||
|
|
||||||
|
| 配置文件 | 算法名称 | backbone | trans | seq | pred | language | character_type |
|
||||||
|
| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: | :-----: |
|
||||||
|
| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 中文繁体 | chinese_cht|
|
||||||
|
| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 英语(区分大小写) | EN |
|
||||||
|
| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 法语 | french |
|
||||||
|
| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 德语 | german |
|
||||||
|
| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 日语 | japan |
|
||||||
|
| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 韩语 | korean |
|
||||||
|
| rec_it_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 意大利语 | it |
|
||||||
|
| rec_xi_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 西班牙语 | xi |
|
||||||
|
| rec_pu_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 葡萄牙语 | pu |
|
||||||
|
| rec_ru_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 俄罗斯语 | ru |
|
||||||
|
| rec_ar_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 阿拉伯语 | ar |
|
||||||
|
| rec_hi_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 印地语 | hi |
|
||||||
|
| rec_ug_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 维吾尔语 | ug |
|
||||||
|
| rec_fa_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 波斯语 | fa |
|
||||||
|
| rec_ur_ite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 乌尔都语 | ur |
|
||||||
|
| rec_rs_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 塞尔维亚(latin)语 | rs |
|
||||||
|
| rec_oc_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 欧西坦语 | oc |
|
||||||
|
| rec_mr_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 马拉地语 | mr |
|
||||||
|
| rec_ne_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 尼泊尔语 | ne |
|
||||||
|
| rec_rsc_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 塞尔维亚(cyrillic)语 | rsc |
|
||||||
|
| rec_bg_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 保加利亚语 | bg |
|
||||||
|
| rec_uk_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 乌克兰语 | uk |
|
||||||
|
| rec_be_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 白俄罗斯语 | be |
|
||||||
|
| rec_te_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 泰卢固语 | te |
|
||||||
|
| rec_ka_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 卡纳达语 | ka |
|
||||||
|
| rec_ta_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 泰米尔语 | ta |
|
||||||
|
|
||||||
多语言模型训练方式与中文模型一致,训练数据集均为100w的合成数据,少量的字体可以在 [百度网盘](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA) 上下载,提取码:frgi。
|
多语言模型训练方式与中文模型一致,训练数据集均为100w的合成数据,少量的字体可以在 [百度网盘](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA) 上下载,提取码:frgi。
|
||||||
|
|
||||||
|
|
|
@ -66,7 +66,7 @@ Start training:
|
||||||
```
|
```
|
||||||
# Set PYTHONPATH path
|
# Set PYTHONPATH path
|
||||||
export PYTHONPATH=$PYTHONPATH:.
|
export PYTHONPATH=$PYTHONPATH:.
|
||||||
# GPU training Support single card and multi-card training, specify the card number through --gpus. If your paddle version is less than 2.0rc1, please use '--selected_gpus'
|
# GPU training Support single card and multi-card training, specify the card number through --gpus.
|
||||||
# Start training, the following command has been written into the train.sh file, just modify the configuration file path in the file
|
# Start training, the following command has been written into the train.sh file, just modify the configuration file path in the file
|
||||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
|
||||||
```
|
```
|
||||||
|
|
|
@ -76,7 +76,7 @@ You can also use `-o` to change the training parameters without modifying the ym
|
||||||
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
|
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
|
||||||
|
|
||||||
# multi-GPU training
|
# multi-GPU training
|
||||||
# Set the GPU ID used by the '--gpus' parameter; If your paddle version is less than 2.0rc1, please use '--selected_gpus'
|
# Set the GPU ID used by the '--gpus' parameter.
|
||||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -315,10 +315,10 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
|
||||||
<a name="MULTILINGUAL_MODEL_INFERENCE"></a>
|
<a name="MULTILINGUAL_MODEL_INFERENCE"></a>
|
||||||
### 4. MULTILINGAUL MODEL INFERENCE
|
### 4. MULTILINGAUL MODEL INFERENCE
|
||||||
If you need to predict other language models, when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results,
|
If you need to predict other language models, when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results,
|
||||||
You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/` path, such as Korean recognition:
|
You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/fonts` path, such as Korean recognition:
|
||||||
|
|
||||||
```
|
```
|
||||||
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_type="korean" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/korean.ttf"
|
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_type="korean" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
|
||||||
```
|
```
|
||||||
![](../imgs_words/korean/1.jpg)
|
![](../imgs_words/korean/1.jpg)
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
After testing, paddleocr can run on glibc 2.23. You can also test other glibc versions or install glic 2.23 for the best compatibility.
|
After testing, paddleocr can run on glibc 2.23. You can also test other glibc versions or install glic 2.23 for the best compatibility.
|
||||||
|
|
||||||
PaddleOCR working environment:
|
PaddleOCR working environment:
|
||||||
- PaddlePaddle 1.8+, Recommend PaddlePaddle 2.0rc1
|
- PaddlePaddle 2.0.0
|
||||||
- python3.7
|
- python3.7
|
||||||
- glibc 2.23
|
- glibc 2.23
|
||||||
|
|
||||||
|
@ -38,10 +38,10 @@ sudo docker container exec -it ppocr /bin/bash
|
||||||
pip3 install --upgrade pip
|
pip3 install --upgrade pip
|
||||||
|
|
||||||
# If you have cuda9 or cuda10 installed on your machine, please run the following command to install
|
# If you have cuda9 or cuda10 installed on your machine, please run the following command to install
|
||||||
python3 -m pip install paddlepaddle-gpu==2.0rc1 -i https://mirror.baidu.com/pypi/simple
|
python3 -m pip install paddlepaddle-gpu==2.0.0 -i https://mirror.baidu.com/pypi/simple
|
||||||
|
|
||||||
# If you only have cpu on your machine, please run the following command to install
|
# If you only have cpu on your machine, please run the following command to install
|
||||||
python3 -m pip install paddlepaddle==2.0rc1 -i https://mirror.baidu.com/pypi/simple
|
python3 -m pip install paddlepaddle==2.0.0 -i https://mirror.baidu.com/pypi/simple
|
||||||
```
|
```
|
||||||
For more software version requirements, please refer to the instructions in [Installation Document](https://www.paddlepaddle.org.cn/install/quick) for operation.
|
For more software version requirements, please refer to the instructions in [Installation Document](https://www.paddlepaddle.org.cn/install/quick) for operation.
|
||||||
|
|
||||||
|
|
|
@ -93,7 +93,7 @@ python3 generate_multi_language_configs.py -l it \
|
||||||
|model name|description|config|model size|download|
|
|model name|description|config|model size|download|
|
||||||
| --- | --- | --- | --- | --- |
|
| --- | --- | --- | --- | --- |
|
||||||
| french_mobile_v2.0_rec |Lightweight model for French recognition|[rec_french_lite_train.yml](../../configs/rec/multi_language/rec_french_lite_train.yml)|2.65M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_train.tar) |
|
| french_mobile_v2.0_rec |Lightweight model for French recognition|[rec_french_lite_train.yml](../../configs/rec/multi_language/rec_french_lite_train.yml)|2.65M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_train.tar) |
|
||||||
| german_mobile_v2.0_rec |Lightweight model for French recognition|[rec_german_lite_train.yml](../../configs/rec/multi_language/rec_german_lite_train.yml)|2.65M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_train.tar) |
|
| german_mobile_v2.0_rec |Lightweight model for German recognition|[rec_german_lite_train.yml](../../configs/rec/multi_language/rec_german_lite_train.yml)|2.65M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_train.tar) |
|
||||||
| korean_mobile_v2.0_rec |Lightweight model for Korean recognition|[rec_korean_lite_train.yml](../../configs/rec/multi_language/rec_korean_lite_train.yml)|3.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_train.tar) |
|
| korean_mobile_v2.0_rec |Lightweight model for Korean recognition|[rec_korean_lite_train.yml](../../configs/rec/multi_language/rec_korean_lite_train.yml)|3.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_train.tar) |
|
||||||
| japan_mobile_v2.0_rec |Lightweight model for Japanese recognition|[rec_japan_lite_train.yml](../../configs/rec/multi_language/rec_japan_lite_train.yml)|4.23M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_train.tar) |
|
| japan_mobile_v2.0_rec |Lightweight model for Japanese recognition|[rec_japan_lite_train.yml](../../configs/rec/multi_language/rec_japan_lite_train.yml)|4.23M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_train.tar) |
|
||||||
| it_mobile_v2.0_rec |Lightweight model for Italian recognition|rec_it_lite_train.yml|2.53M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_train.tar) |
|
| it_mobile_v2.0_rec |Lightweight model for Italian recognition|rec_it_lite_train.yml|2.53M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_train.tar) |
|
||||||
|
|
|
@ -266,15 +266,116 @@ Eval:
|
||||||
<a name="Multi_language"></a>
|
<a name="Multi_language"></a>
|
||||||
- Multi-language
|
- Multi-language
|
||||||
|
|
||||||
PaddleOCR also provides multi-language. The configuration file in `configs/rec/multi_languages` provides multi-language configuration files. Currently, the multi-language algorithms supported by PaddleOCR are:
|
PaddleOCR currently supports 26 (except Chinese) language recognition. A multi-language configuration file template is
|
||||||
|
provided under the path `configs/rec/multi_languages`: [rec_multi_language_lite_train.yml](../../configs/rec/multi_language/rec_multi_language_lite_train.yml)。
|
||||||
|
|
||||||
|
There are two ways to create the required configuration file::
|
||||||
|
|
||||||
|
1. Automatically generated by script
|
||||||
|
|
||||||
|
[generate_multi_language_configs.py](../../configs/rec/multi_language/generate_multi_language_configs.py) Can help you generate configuration files for multi-language models
|
||||||
|
|
||||||
|
- Take Italian as an example, if your data is prepared in the following format:
|
||||||
|
```
|
||||||
|
|-train_data
|
||||||
|
|- it_train.txt # train_set label
|
||||||
|
|- it_val.txt # val_set label
|
||||||
|
|- data
|
||||||
|
|- word_001.jpg
|
||||||
|
|- word_002.jpg
|
||||||
|
|- word_003.jpg
|
||||||
|
| ...
|
||||||
|
```
|
||||||
|
|
||||||
|
You can use the default parameters to generate a configuration file:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# The code needs to be run in the specified directory
|
||||||
|
cd PaddleOCR/configs/rec/multi_language/
|
||||||
|
# Set the configuration file of the language to be generated through the -l or --language parameter.
|
||||||
|
# This command will write the default parameters into the configuration file
|
||||||
|
python3 generate_multi_language_configs.py -l it
|
||||||
|
```
|
||||||
|
|
||||||
|
- If your data is placed in another location, or you want to use your own dictionary, you can generate the configuration file by specifying the relevant parameters:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# -l or --language field is required
|
||||||
|
# --train to modify the training set
|
||||||
|
# --val to modify the validation set
|
||||||
|
# --data_dir to modify the data set directory
|
||||||
|
# --dict to modify the dict path
|
||||||
|
# -o to modify the corresponding default parameters
|
||||||
|
cd PaddleOCR/configs/rec/multi_language/
|
||||||
|
python3 generate_multi_language_configs.py -l it \ # language
|
||||||
|
--train {path/of/train_label.txt} \ # path of train_label
|
||||||
|
--val {path/of/val_label.txt} \ # path of val_label
|
||||||
|
--data_dir {train_data/path} \ # root directory of training data
|
||||||
|
--dict {path/of/dict} \ # path of dict
|
||||||
|
-o Global.use_gpu=False # whether to use gpu
|
||||||
|
...
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Manually modify the configuration file
|
||||||
|
|
||||||
|
You can also manually modify the following fields in the template:
|
||||||
|
|
||||||
|
```
|
||||||
|
Global:
|
||||||
|
use_gpu: True
|
||||||
|
epoch_num: 500
|
||||||
|
...
|
||||||
|
character_type: it # language
|
||||||
|
character_dict_path: {path/of/dict} # path of dict
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: train_data/ # root directory of training data
|
||||||
|
label_file_list: ["./train_data/train_list.txt"] # train label path
|
||||||
|
...
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: SimpleDataSet
|
||||||
|
data_dir: train_data/ # root directory of val data
|
||||||
|
label_file_list: ["./train_data/val_list.txt"] # val label path
|
||||||
|
...
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
Currently, the multi-language algorithms supported by PaddleOCR are:
|
||||||
|
|
||||||
|
| Configuration file | Algorithm name | backbone | trans | seq | pred | language | character_type |
|
||||||
|
| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: | :-----: |
|
||||||
|
| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional | chinese_cht|
|
||||||
|
| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English(Case sensitive) | EN |
|
||||||
|
| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French | french |
|
||||||
|
| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German | german |
|
||||||
|
| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese | japan |
|
||||||
|
| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean | korean |
|
||||||
|
| rec_it_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Italian | it |
|
||||||
|
| rec_xi_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Spanish | xi |
|
||||||
|
| rec_pu_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Portuguese | pu |
|
||||||
|
| rec_ru_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Russia | ru |
|
||||||
|
| rec_ar_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Arabic | ar |
|
||||||
|
| rec_hi_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Hindi | hi |
|
||||||
|
| rec_ug_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Uyghur | ug |
|
||||||
|
| rec_fa_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Persian(Farsi) | fa |
|
||||||
|
| rec_ur_ite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Urdu | ur |
|
||||||
|
| rec_rs_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Serbian(latin) | rs |
|
||||||
|
| rec_oc_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Occitan | oc |
|
||||||
|
| rec_mr_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Marathi | mr |
|
||||||
|
| rec_ne_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Nepali | ne |
|
||||||
|
| rec_rsc_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Serbian(cyrillic) | rsc |
|
||||||
|
| rec_bg_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Bulgarian | bg |
|
||||||
|
| rec_uk_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Ukranian | uk |
|
||||||
|
| rec_be_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Belarusian | be |
|
||||||
|
| rec_te_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Telugu | te |
|
||||||
|
| rec_ka_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Kannada | ka |
|
||||||
|
| rec_ta_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Tamil | ta |
|
||||||
|
|
||||||
| Configuration file | Algorithm name | backbone | trans | seq | pred | language |
|
|
||||||
| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: |
|
|
||||||
| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English |
|
|
||||||
| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French |
|
|
||||||
| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German |
|
|
||||||
| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese |
|
|
||||||
| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean |
|
|
||||||
|
|
||||||
The multi-language model training method is the same as the Chinese model. The training data set is 100w synthetic data. A small amount of fonts and test data can be downloaded on [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA),Extraction code:frgi.
|
The multi-language model training method is the same as the Chinese model. The training data set is 100w synthetic data. A small amount of fonts and test data can be downloaded on [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA),Extraction code:frgi.
|
||||||
|
|
||||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -33,7 +33,7 @@ import paddle.distributed as dist
|
||||||
|
|
||||||
from ppocr.data.imaug import transform, create_operators
|
from ppocr.data.imaug import transform, create_operators
|
||||||
from ppocr.data.simple_dataset import SimpleDataSet
|
from ppocr.data.simple_dataset import SimpleDataSet
|
||||||
from ppocr.data.lmdb_dataset import LMDBDateSet
|
from ppocr.data.lmdb_dataset import LMDBDataSet
|
||||||
|
|
||||||
__all__ = ['build_dataloader', 'transform', 'create_operators']
|
__all__ = ['build_dataloader', 'transform', 'create_operators']
|
||||||
|
|
||||||
|
@ -51,20 +51,21 @@ signal.signal(signal.SIGINT, term_mp)
|
||||||
signal.signal(signal.SIGTERM, term_mp)
|
signal.signal(signal.SIGTERM, term_mp)
|
||||||
|
|
||||||
|
|
||||||
def build_dataloader(config, mode, device, logger):
|
def build_dataloader(config, mode, device, logger, seed=None):
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
|
|
||||||
support_dict = ['SimpleDataSet', 'LMDBDateSet']
|
support_dict = ['SimpleDataSet', 'LMDBDataSet']
|
||||||
module_name = config[mode]['dataset']['name']
|
module_name = config[mode]['dataset']['name']
|
||||||
assert module_name in support_dict, Exception(
|
assert module_name in support_dict, Exception(
|
||||||
'DataSet only support {}'.format(support_dict))
|
'DataSet only support {}'.format(support_dict))
|
||||||
assert mode in ['Train', 'Eval', 'Test'
|
assert mode in ['Train', 'Eval', 'Test'
|
||||||
], "Mode should be Train, Eval or Test."
|
], "Mode should be Train, Eval or Test."
|
||||||
|
|
||||||
dataset = eval(module_name)(config, mode, logger)
|
dataset = eval(module_name)(config, mode, logger, seed)
|
||||||
loader_config = config[mode]['loader']
|
loader_config = config[mode]['loader']
|
||||||
batch_size = loader_config['batch_size_per_card']
|
batch_size = loader_config['batch_size_per_card']
|
||||||
drop_last = loader_config['drop_last']
|
drop_last = loader_config['drop_last']
|
||||||
|
shuffle = loader_config['shuffle']
|
||||||
num_workers = loader_config['num_workers']
|
num_workers = loader_config['num_workers']
|
||||||
if 'use_shared_memory' in loader_config.keys():
|
if 'use_shared_memory' in loader_config.keys():
|
||||||
use_shared_memory = loader_config['use_shared_memory']
|
use_shared_memory = loader_config['use_shared_memory']
|
||||||
|
@ -75,14 +76,14 @@ def build_dataloader(config, mode, device, logger):
|
||||||
batch_sampler = DistributedBatchSampler(
|
batch_sampler = DistributedBatchSampler(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=False,
|
shuffle=shuffle,
|
||||||
drop_last=drop_last)
|
drop_last=drop_last)
|
||||||
else:
|
else:
|
||||||
#Distribute data to single card
|
#Distribute data to single card
|
||||||
batch_sampler = BatchSampler(
|
batch_sampler = BatchSampler(
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=False,
|
shuffle=shuffle,
|
||||||
drop_last=drop_last)
|
drop_last=drop_last)
|
||||||
|
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
|
|
|
@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
|
||||||
from .make_shrink_map import MakeShrinkMap
|
from .make_shrink_map import MakeShrinkMap
|
||||||
from .random_crop_data import EastRandomCropData, PSERandomCrop
|
from .random_crop_data import EastRandomCropData, PSERandomCrop
|
||||||
|
|
||||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg
|
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
|
||||||
from .randaugment import RandAugment
|
from .randaugment import RandAugment
|
||||||
from .operators import *
|
from .operators import *
|
||||||
from .label_ops import *
|
from .label_ops import *
|
||||||
|
|
|
@ -18,6 +18,7 @@ from __future__ import print_function
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import string
|
||||||
|
|
||||||
|
|
||||||
class ClsLabelEncode(object):
|
class ClsLabelEncode(object):
|
||||||
|
@ -92,18 +93,28 @@ class BaseRecLabelEncode(object):
|
||||||
character_type='ch',
|
character_type='ch',
|
||||||
use_space_char=False):
|
use_space_char=False):
|
||||||
support_character_type = [
|
support_character_type = [
|
||||||
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean'
|
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
|
||||||
|
'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
|
||||||
|
'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
|
||||||
|
'mr', 'ne'
|
||||||
]
|
]
|
||||||
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
||||||
support_character_type, character_type)
|
support_character_type, character_type)
|
||||||
|
|
||||||
self.max_text_len = max_text_length
|
self.max_text_len = max_text_length
|
||||||
|
self.beg_str = "sos"
|
||||||
|
self.end_str = "eos"
|
||||||
if character_type == "en":
|
if character_type == "en":
|
||||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||||
dict_character = list(self.character_str)
|
dict_character = list(self.character_str)
|
||||||
elif character_type in ["ch", "french", "german", "japan", "korean"]:
|
elif character_type == "EN_symbol":
|
||||||
|
# same with ASTER setting (use 94 char).
|
||||||
|
self.character_str = string.printable[:-6]
|
||||||
|
dict_character = list(self.character_str)
|
||||||
|
elif character_type in support_character_type:
|
||||||
self.character_str = ""
|
self.character_str = ""
|
||||||
assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch"
|
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
|
||||||
|
character_type)
|
||||||
with open(character_dict_path, "rb") as fin:
|
with open(character_dict_path, "rb") as fin:
|
||||||
lines = fin.readlines()
|
lines = fin.readlines()
|
||||||
for line in lines:
|
for line in lines:
|
||||||
|
@ -112,11 +123,6 @@ class BaseRecLabelEncode(object):
|
||||||
if use_space_char:
|
if use_space_char:
|
||||||
self.character_str += " "
|
self.character_str += " "
|
||||||
dict_character = list(self.character_str)
|
dict_character = list(self.character_str)
|
||||||
elif character_type == "en_sensitive":
|
|
||||||
# same with ASTER setting (use 94 char).
|
|
||||||
import string
|
|
||||||
self.character_str = string.printable[:-6]
|
|
||||||
dict_character = list(self.character_str)
|
|
||||||
self.character_type = character_type
|
self.character_type = character_type
|
||||||
dict_character = self.add_special_char(dict_character)
|
dict_character = self.add_special_char(dict_character)
|
||||||
self.dict = {}
|
self.dict = {}
|
||||||
|
@ -213,3 +219,49 @@ class AttnLabelEncode(BaseRecLabelEncode):
|
||||||
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
|
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
|
||||||
% beg_or_end
|
% beg_or_end
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
|
|
||||||
|
class SRNLabelEncode(BaseRecLabelEncode):
|
||||||
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
max_text_length=25,
|
||||||
|
character_dict_path=None,
|
||||||
|
character_type='en',
|
||||||
|
use_space_char=False,
|
||||||
|
**kwargs):
|
||||||
|
super(SRNLabelEncode,
|
||||||
|
self).__init__(max_text_length, character_dict_path,
|
||||||
|
character_type, use_space_char)
|
||||||
|
|
||||||
|
def add_special_char(self, dict_character):
|
||||||
|
dict_character = dict_character + [self.beg_str, self.end_str]
|
||||||
|
return dict_character
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
text = data['label']
|
||||||
|
text = self.encode(text)
|
||||||
|
char_num = len(self.character_str)
|
||||||
|
if text is None:
|
||||||
|
return None
|
||||||
|
if len(text) > self.max_text_len:
|
||||||
|
return None
|
||||||
|
data['length'] = np.array(len(text))
|
||||||
|
text = text + [char_num] * (self.max_text_len - len(text))
|
||||||
|
data['label'] = np.array(text)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def get_ignored_tokens(self):
|
||||||
|
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||||
|
end_idx = self.get_beg_end_flag_idx("end")
|
||||||
|
return [beg_idx, end_idx]
|
||||||
|
|
||||||
|
def get_beg_end_flag_idx(self, beg_or_end):
|
||||||
|
if beg_or_end == "beg":
|
||||||
|
idx = np.array(self.dict[self.beg_str])
|
||||||
|
elif beg_or_end == "end":
|
||||||
|
idx = np.array(self.dict[self.end_str])
|
||||||
|
else:
|
||||||
|
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
|
||||||
|
% beg_or_end
|
||||||
|
return idx
|
||||||
|
|
|
@ -12,20 +12,6 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -77,6 +63,26 @@ class RecResizeImg(object):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class SRNRecResizeImg(object):
|
||||||
|
def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
|
||||||
|
self.image_shape = image_shape
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.max_text_length = max_text_length
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
img = data['image']
|
||||||
|
norm_img = resize_norm_img_srn(img, self.image_shape)
|
||||||
|
data['image'] = norm_img
|
||||||
|
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
||||||
|
srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
|
||||||
|
|
||||||
|
data['encoder_word_pos'] = encoder_word_pos
|
||||||
|
data['gsrm_word_pos'] = gsrm_word_pos
|
||||||
|
data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1
|
||||||
|
data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def resize_norm_img(img, image_shape):
|
def resize_norm_img(img, image_shape):
|
||||||
imgC, imgH, imgW = image_shape
|
imgC, imgH, imgW = image_shape
|
||||||
h = img.shape[0]
|
h = img.shape[0]
|
||||||
|
@ -103,7 +109,7 @@ def resize_norm_img(img, image_shape):
|
||||||
def resize_norm_img_chinese(img, image_shape):
|
def resize_norm_img_chinese(img, image_shape):
|
||||||
imgC, imgH, imgW = image_shape
|
imgC, imgH, imgW = image_shape
|
||||||
# todo: change to 0 and modified image shape
|
# todo: change to 0 and modified image shape
|
||||||
max_wh_ratio = 0
|
max_wh_ratio = imgW * 1.0 / imgH
|
||||||
h, w = img.shape[0], img.shape[1]
|
h, w = img.shape[0], img.shape[1]
|
||||||
ratio = w * 1.0 / h
|
ratio = w * 1.0 / h
|
||||||
max_wh_ratio = max(max_wh_ratio, ratio)
|
max_wh_ratio = max(max_wh_ratio, ratio)
|
||||||
|
@ -126,6 +132,60 @@ def resize_norm_img_chinese(img, image_shape):
|
||||||
return padding_im
|
return padding_im
|
||||||
|
|
||||||
|
|
||||||
|
def resize_norm_img_srn(img, image_shape):
|
||||||
|
imgC, imgH, imgW = image_shape
|
||||||
|
|
||||||
|
img_black = np.zeros((imgH, imgW))
|
||||||
|
im_hei = img.shape[0]
|
||||||
|
im_wid = img.shape[1]
|
||||||
|
|
||||||
|
if im_wid <= im_hei * 1:
|
||||||
|
img_new = cv2.resize(img, (imgH * 1, imgH))
|
||||||
|
elif im_wid <= im_hei * 2:
|
||||||
|
img_new = cv2.resize(img, (imgH * 2, imgH))
|
||||||
|
elif im_wid <= im_hei * 3:
|
||||||
|
img_new = cv2.resize(img, (imgH * 3, imgH))
|
||||||
|
else:
|
||||||
|
img_new = cv2.resize(img, (imgW, imgH))
|
||||||
|
|
||||||
|
img_np = np.asarray(img_new)
|
||||||
|
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
|
||||||
|
img_black[:, 0:img_np.shape[1]] = img_np
|
||||||
|
img_black = img_black[:, :, np.newaxis]
|
||||||
|
|
||||||
|
row, col, c = img_black.shape
|
||||||
|
c = 1
|
||||||
|
|
||||||
|
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def srn_other_inputs(image_shape, num_heads, max_text_length):
|
||||||
|
|
||||||
|
imgC, imgH, imgW = image_shape
|
||||||
|
feature_dim = int((imgH / 8) * (imgW / 8))
|
||||||
|
|
||||||
|
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
|
||||||
|
(feature_dim, 1)).astype('int64')
|
||||||
|
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
|
||||||
|
(max_text_length, 1)).astype('int64')
|
||||||
|
|
||||||
|
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||||
|
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
|
||||||
|
[1, max_text_length, max_text_length])
|
||||||
|
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
|
||||||
|
[num_heads, 1, 1]) * [-1e9]
|
||||||
|
|
||||||
|
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
|
||||||
|
[1, max_text_length, max_text_length])
|
||||||
|
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
|
||||||
|
[num_heads, 1, 1]) * [-1e9]
|
||||||
|
|
||||||
|
return [
|
||||||
|
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||||
|
gsrm_slf_attn_bias2
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def flag():
|
def flag():
|
||||||
"""
|
"""
|
||||||
flag
|
flag
|
||||||
|
|
|
@ -21,7 +21,7 @@ from .imaug import transform, create_operators
|
||||||
|
|
||||||
|
|
||||||
class LMDBDateSet(Dataset):
|
class LMDBDateSet(Dataset):
|
||||||
def __init__(self, config, mode, logger):
|
def __init__(self, config, mode, logger, seed=None):
|
||||||
super(LMDBDateSet, self).__init__()
|
super(LMDBDateSet, self).__init__()
|
||||||
|
|
||||||
global_config = config['Global']
|
global_config = config['Global']
|
||||||
|
|
|
@ -20,7 +20,7 @@ from .imaug import transform, create_operators
|
||||||
|
|
||||||
|
|
||||||
class SimpleDataSet(Dataset):
|
class SimpleDataSet(Dataset):
|
||||||
def __init__(self, config, mode, logger):
|
def __init__(self, config, mode, logger, seed=None):
|
||||||
super(SimpleDataSet, self).__init__()
|
super(SimpleDataSet, self).__init__()
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
|
@ -41,6 +41,7 @@ class SimpleDataSet(Dataset):
|
||||||
self.data_dir = dataset_config['data_dir']
|
self.data_dir = dataset_config['data_dir']
|
||||||
self.do_shuffle = loader_config['shuffle']
|
self.do_shuffle = loader_config['shuffle']
|
||||||
|
|
||||||
|
self.seed = seed
|
||||||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
||||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||||
|
@ -55,6 +56,7 @@ class SimpleDataSet(Dataset):
|
||||||
for idx, file in enumerate(file_list):
|
for idx, file in enumerate(file_list):
|
||||||
with open(file, "rb") as f:
|
with open(file, "rb") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
|
random.seed(self.seed)
|
||||||
lines = random.sample(lines,
|
lines = random.sample(lines,
|
||||||
round(len(lines) * ratio_list[idx]))
|
round(len(lines) * ratio_list[idx]))
|
||||||
data_lines.extend(lines)
|
data_lines.extend(lines)
|
||||||
|
@ -62,6 +64,7 @@ class SimpleDataSet(Dataset):
|
||||||
|
|
||||||
def shuffle_data_random(self):
|
def shuffle_data_random(self):
|
||||||
if self.do_shuffle:
|
if self.do_shuffle:
|
||||||
|
random.seed(self.seed)
|
||||||
random.shuffle(self.data_lines)
|
random.shuffle(self.data_lines)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
@ -23,11 +23,14 @@ def build_loss(config):
|
||||||
|
|
||||||
# rec loss
|
# rec loss
|
||||||
from .rec_ctc_loss import CTCLoss
|
from .rec_ctc_loss import CTCLoss
|
||||||
|
from .rec_srn_loss import SRNLoss
|
||||||
|
|
||||||
# cls loss
|
# cls loss
|
||||||
from .cls_loss import ClsLoss
|
from .cls_loss import ClsLoss
|
||||||
|
|
||||||
support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss']
|
support_dict = [
|
||||||
|
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'SRNLoss'
|
||||||
|
]
|
||||||
|
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
|
||||||
|
class SRNLoss(nn.Layer):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(SRNLoss, self).__init__()
|
||||||
|
self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum")
|
||||||
|
|
||||||
|
def forward(self, predicts, batch):
|
||||||
|
predict = predicts['predict']
|
||||||
|
word_predict = predicts['word_out']
|
||||||
|
gsrm_predict = predicts['gsrm_out']
|
||||||
|
label = batch[1]
|
||||||
|
|
||||||
|
casted_label = paddle.cast(x=label, dtype='int64')
|
||||||
|
casted_label = paddle.reshape(x=casted_label, shape=[-1, 1])
|
||||||
|
|
||||||
|
cost_word = self.loss_func(word_predict, label=casted_label)
|
||||||
|
cost_gsrm = self.loss_func(gsrm_predict, label=casted_label)
|
||||||
|
cost_vsfd = self.loss_func(predict, label=casted_label)
|
||||||
|
|
||||||
|
cost_word = paddle.reshape(x=paddle.sum(cost_word), shape=[1])
|
||||||
|
cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1])
|
||||||
|
cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1])
|
||||||
|
|
||||||
|
sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15
|
||||||
|
|
||||||
|
return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}
|
|
@ -33,8 +33,6 @@ class RecMetric(object):
|
||||||
if pred == target:
|
if pred == target:
|
||||||
correct_num += 1
|
correct_num += 1
|
||||||
all_num += 1
|
all_num += 1
|
||||||
# if all_num < 10 and kwargs.get('show_str', False):
|
|
||||||
# print('{} -> {}'.format(pred, target))
|
|
||||||
self.correct_num += correct_num
|
self.correct_num += correct_num
|
||||||
self.all_num += all_num
|
self.all_num += all_num
|
||||||
self.norm_edit_dis += norm_edit_dis
|
self.norm_edit_dis += norm_edit_dis
|
||||||
|
@ -50,7 +48,7 @@ class RecMetric(object):
|
||||||
'norm_edit_dis': 0,
|
'norm_edit_dis': 0,
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
acc = self.correct_num / self.all_num
|
acc = 1.0 * self.correct_num / self.all_num
|
||||||
norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
|
norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
|
||||||
self.reset()
|
self.reset()
|
||||||
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
|
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
|
||||||
|
|
|
@ -68,11 +68,14 @@ class BaseModel(nn.Layer):
|
||||||
config["Head"]['in_channels'] = in_channels
|
config["Head"]['in_channels'] = in_channels
|
||||||
self.head = build_head(config["Head"])
|
self.head = build_head(config["Head"])
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, data=None):
|
||||||
if self.use_transform:
|
if self.use_transform:
|
||||||
x = self.transform(x)
|
x = self.transform(x)
|
||||||
x = self.backbone(x)
|
x = self.backbone(x)
|
||||||
if self.use_neck:
|
if self.use_neck:
|
||||||
x = self.neck(x)
|
x = self.neck(x)
|
||||||
x = self.head(x)
|
if data is None:
|
||||||
|
x = self.head(x)
|
||||||
|
else:
|
||||||
|
x = self.head(x, data)
|
||||||
return x
|
return x
|
||||||
|
|
|
@ -24,7 +24,8 @@ def build_backbone(config, model_type):
|
||||||
elif model_type == 'rec' or model_type == 'cls':
|
elif model_type == 'rec' or model_type == 'cls':
|
||||||
from .rec_mobilenet_v3 import MobileNetV3
|
from .rec_mobilenet_v3 import MobileNetV3
|
||||||
from .rec_resnet_vd import ResNet
|
from .rec_resnet_vd import ResNet
|
||||||
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_FPN']
|
from .rec_resnet_fpn import ResNetFPN
|
||||||
|
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
|
@ -58,15 +58,15 @@ class MobileNetV3(nn.Layer):
|
||||||
[5, 72, 40, True, 'relu', 2],
|
[5, 72, 40, True, 'relu', 2],
|
||||||
[5, 120, 40, True, 'relu', 1],
|
[5, 120, 40, True, 'relu', 1],
|
||||||
[5, 120, 40, True, 'relu', 1],
|
[5, 120, 40, True, 'relu', 1],
|
||||||
[3, 240, 80, False, 'hard_swish', 2],
|
[3, 240, 80, False, 'hardswish', 2],
|
||||||
[3, 200, 80, False, 'hard_swish', 1],
|
[3, 200, 80, False, 'hardswish', 1],
|
||||||
[3, 184, 80, False, 'hard_swish', 1],
|
[3, 184, 80, False, 'hardswish', 1],
|
||||||
[3, 184, 80, False, 'hard_swish', 1],
|
[3, 184, 80, False, 'hardswish', 1],
|
||||||
[3, 480, 112, True, 'hard_swish', 1],
|
[3, 480, 112, True, 'hardswish', 1],
|
||||||
[3, 672, 112, True, 'hard_swish', 1],
|
[3, 672, 112, True, 'hardswish', 1],
|
||||||
[5, 672, 160, True, 'hard_swish', 2],
|
[5, 672, 160, True, 'hardswish', 2],
|
||||||
[5, 960, 160, True, 'hard_swish', 1],
|
[5, 960, 160, True, 'hardswish', 1],
|
||||||
[5, 960, 160, True, 'hard_swish', 1],
|
[5, 960, 160, True, 'hardswish', 1],
|
||||||
]
|
]
|
||||||
cls_ch_squeeze = 960
|
cls_ch_squeeze = 960
|
||||||
elif model_name == "small":
|
elif model_name == "small":
|
||||||
|
@ -75,14 +75,14 @@ class MobileNetV3(nn.Layer):
|
||||||
[3, 16, 16, True, 'relu', 2],
|
[3, 16, 16, True, 'relu', 2],
|
||||||
[3, 72, 24, False, 'relu', 2],
|
[3, 72, 24, False, 'relu', 2],
|
||||||
[3, 88, 24, False, 'relu', 1],
|
[3, 88, 24, False, 'relu', 1],
|
||||||
[5, 96, 40, True, 'hard_swish', 2],
|
[5, 96, 40, True, 'hardswish', 2],
|
||||||
[5, 240, 40, True, 'hard_swish', 1],
|
[5, 240, 40, True, 'hardswish', 1],
|
||||||
[5, 240, 40, True, 'hard_swish', 1],
|
[5, 240, 40, True, 'hardswish', 1],
|
||||||
[5, 120, 48, True, 'hard_swish', 1],
|
[5, 120, 48, True, 'hardswish', 1],
|
||||||
[5, 144, 48, True, 'hard_swish', 1],
|
[5, 144, 48, True, 'hardswish', 1],
|
||||||
[5, 288, 96, True, 'hard_swish', 2],
|
[5, 288, 96, True, 'hardswish', 2],
|
||||||
[5, 576, 96, True, 'hard_swish', 1],
|
[5, 576, 96, True, 'hardswish', 1],
|
||||||
[5, 576, 96, True, 'hard_swish', 1],
|
[5, 576, 96, True, 'hardswish', 1],
|
||||||
]
|
]
|
||||||
cls_ch_squeeze = 576
|
cls_ch_squeeze = 576
|
||||||
else:
|
else:
|
||||||
|
@ -102,7 +102,7 @@ class MobileNetV3(nn.Layer):
|
||||||
padding=1,
|
padding=1,
|
||||||
groups=1,
|
groups=1,
|
||||||
if_act=True,
|
if_act=True,
|
||||||
act='hard_swish',
|
act='hardswish',
|
||||||
name='conv1')
|
name='conv1')
|
||||||
|
|
||||||
self.stages = []
|
self.stages = []
|
||||||
|
@ -112,7 +112,8 @@ class MobileNetV3(nn.Layer):
|
||||||
inplanes = make_divisible(inplanes * scale)
|
inplanes = make_divisible(inplanes * scale)
|
||||||
for (k, exp, c, se, nl, s) in cfg:
|
for (k, exp, c, se, nl, s) in cfg:
|
||||||
se = se and not self.disable_se
|
se = se and not self.disable_se
|
||||||
if s == 2 and i > 2:
|
start_idx = 2 if model_name == 'large' else 0
|
||||||
|
if s == 2 and i > start_idx:
|
||||||
self.out_channels.append(inplanes)
|
self.out_channels.append(inplanes)
|
||||||
self.stages.append(nn.Sequential(*block_list))
|
self.stages.append(nn.Sequential(*block_list))
|
||||||
block_list = []
|
block_list = []
|
||||||
|
@ -137,7 +138,7 @@ class MobileNetV3(nn.Layer):
|
||||||
padding=0,
|
padding=0,
|
||||||
groups=1,
|
groups=1,
|
||||||
if_act=True,
|
if_act=True,
|
||||||
act='hard_swish',
|
act='hardswish',
|
||||||
name='conv_last'))
|
name='conv_last'))
|
||||||
self.stages.append(nn.Sequential(*block_list))
|
self.stages.append(nn.Sequential(*block_list))
|
||||||
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
|
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
|
||||||
|
@ -191,10 +192,11 @@ class ConvBNLayer(nn.Layer):
|
||||||
if self.if_act:
|
if self.if_act:
|
||||||
if self.act == "relu":
|
if self.act == "relu":
|
||||||
x = F.relu(x)
|
x = F.relu(x)
|
||||||
elif self.act == "hard_swish":
|
elif self.act == "hardswish":
|
||||||
x = F.activation.hard_swish(x)
|
x = F.hardswish(x)
|
||||||
else:
|
else:
|
||||||
print("The activation function is selected incorrectly.")
|
print("The activation function({}) is selected incorrectly.".
|
||||||
|
format(self.act))
|
||||||
exit()
|
exit()
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -281,5 +283,5 @@ class SEModule(nn.Layer):
|
||||||
outputs = self.conv1(outputs)
|
outputs = self.conv1(outputs)
|
||||||
outputs = F.relu(outputs)
|
outputs = F.relu(outputs)
|
||||||
outputs = self.conv2(outputs)
|
outputs = self.conv2(outputs)
|
||||||
outputs = F.activation.hard_sigmoid(outputs)
|
outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
|
||||||
return inputs * outputs
|
return inputs * outputs
|
||||||
|
|
|
@ -51,15 +51,15 @@ class MobileNetV3(nn.Layer):
|
||||||
[5, 72, 40, True, 'relu', (large_stride[2], 1)],
|
[5, 72, 40, True, 'relu', (large_stride[2], 1)],
|
||||||
[5, 120, 40, True, 'relu', 1],
|
[5, 120, 40, True, 'relu', 1],
|
||||||
[5, 120, 40, True, 'relu', 1],
|
[5, 120, 40, True, 'relu', 1],
|
||||||
[3, 240, 80, False, 'hard_swish', 1],
|
[3, 240, 80, False, 'hardswish', 1],
|
||||||
[3, 200, 80, False, 'hard_swish', 1],
|
[3, 200, 80, False, 'hardswish', 1],
|
||||||
[3, 184, 80, False, 'hard_swish', 1],
|
[3, 184, 80, False, 'hardswish', 1],
|
||||||
[3, 184, 80, False, 'hard_swish', 1],
|
[3, 184, 80, False, 'hardswish', 1],
|
||||||
[3, 480, 112, True, 'hard_swish', 1],
|
[3, 480, 112, True, 'hardswish', 1],
|
||||||
[3, 672, 112, True, 'hard_swish', 1],
|
[3, 672, 112, True, 'hardswish', 1],
|
||||||
[5, 672, 160, True, 'hard_swish', (large_stride[3], 1)],
|
[5, 672, 160, True, 'hardswish', (large_stride[3], 1)],
|
||||||
[5, 960, 160, True, 'hard_swish', 1],
|
[5, 960, 160, True, 'hardswish', 1],
|
||||||
[5, 960, 160, True, 'hard_swish', 1],
|
[5, 960, 160, True, 'hardswish', 1],
|
||||||
]
|
]
|
||||||
cls_ch_squeeze = 960
|
cls_ch_squeeze = 960
|
||||||
elif model_name == "small":
|
elif model_name == "small":
|
||||||
|
@ -68,14 +68,14 @@ class MobileNetV3(nn.Layer):
|
||||||
[3, 16, 16, True, 'relu', (small_stride[0], 1)],
|
[3, 16, 16, True, 'relu', (small_stride[0], 1)],
|
||||||
[3, 72, 24, False, 'relu', (small_stride[1], 1)],
|
[3, 72, 24, False, 'relu', (small_stride[1], 1)],
|
||||||
[3, 88, 24, False, 'relu', 1],
|
[3, 88, 24, False, 'relu', 1],
|
||||||
[5, 96, 40, True, 'hard_swish', (small_stride[2], 1)],
|
[5, 96, 40, True, 'hardswish', (small_stride[2], 1)],
|
||||||
[5, 240, 40, True, 'hard_swish', 1],
|
[5, 240, 40, True, 'hardswish', 1],
|
||||||
[5, 240, 40, True, 'hard_swish', 1],
|
[5, 240, 40, True, 'hardswish', 1],
|
||||||
[5, 120, 48, True, 'hard_swish', 1],
|
[5, 120, 48, True, 'hardswish', 1],
|
||||||
[5, 144, 48, True, 'hard_swish', 1],
|
[5, 144, 48, True, 'hardswish', 1],
|
||||||
[5, 288, 96, True, 'hard_swish', (small_stride[3], 1)],
|
[5, 288, 96, True, 'hardswish', (small_stride[3], 1)],
|
||||||
[5, 576, 96, True, 'hard_swish', 1],
|
[5, 576, 96, True, 'hardswish', 1],
|
||||||
[5, 576, 96, True, 'hard_swish', 1],
|
[5, 576, 96, True, 'hardswish', 1],
|
||||||
]
|
]
|
||||||
cls_ch_squeeze = 576
|
cls_ch_squeeze = 576
|
||||||
else:
|
else:
|
||||||
|
@ -96,7 +96,7 @@ class MobileNetV3(nn.Layer):
|
||||||
padding=1,
|
padding=1,
|
||||||
groups=1,
|
groups=1,
|
||||||
if_act=True,
|
if_act=True,
|
||||||
act='hard_swish',
|
act='hardswish',
|
||||||
name='conv1')
|
name='conv1')
|
||||||
i = 0
|
i = 0
|
||||||
block_list = []
|
block_list = []
|
||||||
|
@ -124,7 +124,7 @@ class MobileNetV3(nn.Layer):
|
||||||
padding=0,
|
padding=0,
|
||||||
groups=1,
|
groups=1,
|
||||||
if_act=True,
|
if_act=True,
|
||||||
act='hard_swish',
|
act='hardswish',
|
||||||
name='conv_last')
|
name='conv_last')
|
||||||
|
|
||||||
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
|
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
|
||||||
|
|
|
@ -0,0 +1,307 @@
|
||||||
|
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
#you may not use this file except in compliance with the License.
|
||||||
|
#You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
#Unless required by applicable law or agreed to in writing, software
|
||||||
|
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
#See the License for the specific language governing permissions and
|
||||||
|
#limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from paddle import nn, ParamAttr
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
import paddle
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
__all__ = ["ResNetFPN"]
|
||||||
|
|
||||||
|
|
||||||
|
class ResNetFPN(nn.Layer):
|
||||||
|
def __init__(self, in_channels=1, layers=50, **kwargs):
|
||||||
|
super(ResNetFPN, self).__init__()
|
||||||
|
supported_layers = {
|
||||||
|
18: {
|
||||||
|
'depth': [2, 2, 2, 2],
|
||||||
|
'block_class': BasicBlock
|
||||||
|
},
|
||||||
|
34: {
|
||||||
|
'depth': [3, 4, 6, 3],
|
||||||
|
'block_class': BasicBlock
|
||||||
|
},
|
||||||
|
50: {
|
||||||
|
'depth': [3, 4, 6, 3],
|
||||||
|
'block_class': BottleneckBlock
|
||||||
|
},
|
||||||
|
101: {
|
||||||
|
'depth': [3, 4, 23, 3],
|
||||||
|
'block_class': BottleneckBlock
|
||||||
|
},
|
||||||
|
152: {
|
||||||
|
'depth': [3, 8, 36, 3],
|
||||||
|
'block_class': BottleneckBlock
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)]
|
||||||
|
num_filters = [64, 128, 256, 512]
|
||||||
|
self.depth = supported_layers[layers]['depth']
|
||||||
|
self.F = []
|
||||||
|
self.conv = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=64,
|
||||||
|
kernel_size=7,
|
||||||
|
stride=2,
|
||||||
|
act="relu",
|
||||||
|
name="conv1")
|
||||||
|
self.block_list = []
|
||||||
|
in_ch = 64
|
||||||
|
if layers >= 50:
|
||||||
|
for block in range(len(self.depth)):
|
||||||
|
for i in range(self.depth[block]):
|
||||||
|
if layers in [101, 152] and block == 2:
|
||||||
|
if i == 0:
|
||||||
|
conv_name = "res" + str(block + 2) + "a"
|
||||||
|
else:
|
||||||
|
conv_name = "res" + str(block + 2) + "b" + str(i)
|
||||||
|
else:
|
||||||
|
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||||
|
block_list = self.add_sublayer(
|
||||||
|
"bottleneckBlock_{}_{}".format(block, i),
|
||||||
|
BottleneckBlock(
|
||||||
|
in_channels=in_ch,
|
||||||
|
out_channels=num_filters[block],
|
||||||
|
stride=stride_list[block] if i == 0 else 1,
|
||||||
|
name=conv_name))
|
||||||
|
in_ch = num_filters[block] * 4
|
||||||
|
self.block_list.append(block_list)
|
||||||
|
self.F.append(block_list)
|
||||||
|
else:
|
||||||
|
for block in range(len(self.depth)):
|
||||||
|
for i in range(self.depth[block]):
|
||||||
|
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||||
|
if i == 0 and block != 0:
|
||||||
|
stride = (2, 1)
|
||||||
|
else:
|
||||||
|
stride = (1, 1)
|
||||||
|
basic_block = self.add_sublayer(
|
||||||
|
conv_name,
|
||||||
|
BasicBlock(
|
||||||
|
in_channels=in_ch,
|
||||||
|
out_channels=num_filters[block],
|
||||||
|
stride=stride_list[block] if i == 0 else 1,
|
||||||
|
is_first=block == i == 0,
|
||||||
|
name=conv_name))
|
||||||
|
in_ch = basic_block.out_channels
|
||||||
|
self.block_list.append(basic_block)
|
||||||
|
out_ch_list = [in_ch // 4, in_ch // 2, in_ch]
|
||||||
|
self.base_block = []
|
||||||
|
self.conv_trans = []
|
||||||
|
self.bn_block = []
|
||||||
|
for i in [-2, -3]:
|
||||||
|
in_channels = out_ch_list[i + 1] + out_ch_list[i]
|
||||||
|
|
||||||
|
self.base_block.append(
|
||||||
|
self.add_sublayer(
|
||||||
|
"F_{}_base_block_0".format(i),
|
||||||
|
nn.Conv2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_ch_list[i],
|
||||||
|
kernel_size=1,
|
||||||
|
weight_attr=ParamAttr(trainable=True),
|
||||||
|
bias_attr=ParamAttr(trainable=True))))
|
||||||
|
self.base_block.append(
|
||||||
|
self.add_sublayer(
|
||||||
|
"F_{}_base_block_1".format(i),
|
||||||
|
nn.Conv2D(
|
||||||
|
in_channels=out_ch_list[i],
|
||||||
|
out_channels=out_ch_list[i],
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
weight_attr=ParamAttr(trainable=True),
|
||||||
|
bias_attr=ParamAttr(trainable=True))))
|
||||||
|
self.base_block.append(
|
||||||
|
self.add_sublayer(
|
||||||
|
"F_{}_base_block_2".format(i),
|
||||||
|
nn.BatchNorm(
|
||||||
|
num_channels=out_ch_list[i],
|
||||||
|
act="relu",
|
||||||
|
param_attr=ParamAttr(trainable=True),
|
||||||
|
bias_attr=ParamAttr(trainable=True))))
|
||||||
|
self.base_block.append(
|
||||||
|
self.add_sublayer(
|
||||||
|
"F_{}_base_block_3".format(i),
|
||||||
|
nn.Conv2D(
|
||||||
|
in_channels=out_ch_list[i],
|
||||||
|
out_channels=512,
|
||||||
|
kernel_size=1,
|
||||||
|
bias_attr=ParamAttr(trainable=True),
|
||||||
|
weight_attr=ParamAttr(trainable=True))))
|
||||||
|
self.out_channels = 512
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
fpn_list = []
|
||||||
|
F = []
|
||||||
|
for i in range(len(self.depth)):
|
||||||
|
fpn_list.append(np.sum(self.depth[:i + 1]))
|
||||||
|
|
||||||
|
for i, block in enumerate(self.block_list):
|
||||||
|
x = block(x)
|
||||||
|
for number in fpn_list:
|
||||||
|
if i + 1 == number:
|
||||||
|
F.append(x)
|
||||||
|
base = F[-1]
|
||||||
|
|
||||||
|
j = 0
|
||||||
|
for i, block in enumerate(self.base_block):
|
||||||
|
if i % 3 == 0 and i < 6:
|
||||||
|
j = j + 1
|
||||||
|
b, c, w, h = F[-j - 1].shape
|
||||||
|
if [w, h] == list(base.shape[2:]):
|
||||||
|
base = base
|
||||||
|
else:
|
||||||
|
base = self.conv_trans[j - 1](base)
|
||||||
|
base = self.bn_block[j - 1](base)
|
||||||
|
base = paddle.concat([base, F[-j - 1]], axis=1)
|
||||||
|
base = block(base)
|
||||||
|
return base
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBNLayer(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
groups=1,
|
||||||
|
act=None,
|
||||||
|
name=None):
|
||||||
|
super(ConvBNLayer, self).__init__()
|
||||||
|
self.conv = nn.Conv2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=2 if stride == (1, 1) else kernel_size,
|
||||||
|
dilation=2 if stride == (1, 1) else 1,
|
||||||
|
stride=stride,
|
||||||
|
padding=(kernel_size - 1) // 2,
|
||||||
|
groups=groups,
|
||||||
|
weight_attr=ParamAttr(name=name + '.conv2d.output.1.w_0'),
|
||||||
|
bias_attr=False, )
|
||||||
|
|
||||||
|
if name == "conv1":
|
||||||
|
bn_name = "bn_" + name
|
||||||
|
else:
|
||||||
|
bn_name = "bn" + name[3:]
|
||||||
|
self.bn = nn.BatchNorm(
|
||||||
|
num_channels=out_channels,
|
||||||
|
act=act,
|
||||||
|
param_attr=ParamAttr(name=name + '.output.1.w_0'),
|
||||||
|
bias_attr=ParamAttr(name=name + '.output.1.b_0'),
|
||||||
|
moving_mean_name=bn_name + "_mean",
|
||||||
|
moving_variance_name=bn_name + "_variance")
|
||||||
|
|
||||||
|
def __call__(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ShortCut(nn.Layer):
|
||||||
|
def __init__(self, in_channels, out_channels, stride, name, is_first=False):
|
||||||
|
super(ShortCut, self).__init__()
|
||||||
|
self.use_conv = True
|
||||||
|
|
||||||
|
if in_channels != out_channels or stride != 1 or is_first == True:
|
||||||
|
if stride == (1, 1):
|
||||||
|
self.conv = ConvBNLayer(
|
||||||
|
in_channels, out_channels, 1, 1, name=name)
|
||||||
|
else: # stride==(2,2)
|
||||||
|
self.conv = ConvBNLayer(
|
||||||
|
in_channels, out_channels, 1, stride, name=name)
|
||||||
|
else:
|
||||||
|
self.use_conv = False
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.use_conv:
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class BottleneckBlock(nn.Layer):
|
||||||
|
def __init__(self, in_channels, out_channels, stride, name):
|
||||||
|
super(BottleneckBlock, self).__init__()
|
||||||
|
self.conv0 = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
act='relu',
|
||||||
|
name=name + "_branch2a")
|
||||||
|
self.conv1 = ConvBNLayer(
|
||||||
|
in_channels=out_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=stride,
|
||||||
|
act='relu',
|
||||||
|
name=name + "_branch2b")
|
||||||
|
|
||||||
|
self.conv2 = ConvBNLayer(
|
||||||
|
in_channels=out_channels,
|
||||||
|
out_channels=out_channels * 4,
|
||||||
|
kernel_size=1,
|
||||||
|
act=None,
|
||||||
|
name=name + "_branch2c")
|
||||||
|
|
||||||
|
self.short = ShortCut(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels * 4,
|
||||||
|
stride=stride,
|
||||||
|
is_first=False,
|
||||||
|
name=name + "_branch1")
|
||||||
|
self.out_channels = out_channels * 4
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = self.conv0(x)
|
||||||
|
y = self.conv1(y)
|
||||||
|
y = self.conv2(y)
|
||||||
|
y = y + self.short(x)
|
||||||
|
y = F.relu(y)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlock(nn.Layer):
|
||||||
|
def __init__(self, in_channels, out_channels, stride, name, is_first):
|
||||||
|
super(BasicBlock, self).__init__()
|
||||||
|
self.conv0 = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
act='relu',
|
||||||
|
stride=stride,
|
||||||
|
name=name + "_branch2a")
|
||||||
|
self.conv1 = ConvBNLayer(
|
||||||
|
in_channels=out_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
act=None,
|
||||||
|
name=name + "_branch2b")
|
||||||
|
self.short = ShortCut(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
stride=stride,
|
||||||
|
is_first=is_first,
|
||||||
|
name=name + "_branch1")
|
||||||
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = self.conv0(x)
|
||||||
|
y = self.conv1(y)
|
||||||
|
y = y + self.short(x)
|
||||||
|
return F.relu(y)
|
|
@ -23,10 +23,13 @@ def build_head(config):
|
||||||
|
|
||||||
# rec head
|
# rec head
|
||||||
from .rec_ctc_head import CTCHead
|
from .rec_ctc_head import CTCHead
|
||||||
|
from .rec_srn_head import SRNHead
|
||||||
|
|
||||||
# cls head
|
# cls head
|
||||||
from .cls_head import ClsHead
|
from .cls_head import ClsHead
|
||||||
support_dict = ['DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead']
|
support_dict = [
|
||||||
|
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'SRNHead'
|
||||||
|
]
|
||||||
|
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
assert module_name in support_dict, Exception('head only support {}'.format(
|
assert module_name in support_dict, Exception('head only support {}'.format(
|
||||||
|
|
|
@ -0,0 +1,279 @@
|
||||||
|
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import math
|
||||||
|
import paddle
|
||||||
|
from paddle import nn, ParamAttr
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
import numpy as np
|
||||||
|
from .self_attention import WrapEncoderForFeature
|
||||||
|
from .self_attention import WrapEncoder
|
||||||
|
from paddle.static import Program
|
||||||
|
from ppocr.modeling.backbones.rec_resnet_fpn import ResNetFPN
|
||||||
|
import paddle.fluid.framework as framework
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
gradient_clip = 10
|
||||||
|
|
||||||
|
|
||||||
|
class PVAM(nn.Layer):
|
||||||
|
def __init__(self, in_channels, char_num, max_text_length, num_heads,
|
||||||
|
num_encoder_tus, hidden_dims):
|
||||||
|
super(PVAM, self).__init__()
|
||||||
|
self.char_num = char_num
|
||||||
|
self.max_length = max_text_length
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_encoder_TUs = num_encoder_tus
|
||||||
|
self.hidden_dims = hidden_dims
|
||||||
|
# Transformer encoder
|
||||||
|
t = 256
|
||||||
|
c = 512
|
||||||
|
self.wrap_encoder_for_feature = WrapEncoderForFeature(
|
||||||
|
src_vocab_size=1,
|
||||||
|
max_length=t,
|
||||||
|
n_layer=self.num_encoder_TUs,
|
||||||
|
n_head=self.num_heads,
|
||||||
|
d_key=int(self.hidden_dims / self.num_heads),
|
||||||
|
d_value=int(self.hidden_dims / self.num_heads),
|
||||||
|
d_model=self.hidden_dims,
|
||||||
|
d_inner_hid=self.hidden_dims,
|
||||||
|
prepostprocess_dropout=0.1,
|
||||||
|
attention_dropout=0.1,
|
||||||
|
relu_dropout=0.1,
|
||||||
|
preprocess_cmd="n",
|
||||||
|
postprocess_cmd="da",
|
||||||
|
weight_sharing=True)
|
||||||
|
|
||||||
|
# PVAM
|
||||||
|
self.flatten0 = paddle.nn.Flatten(start_axis=0, stop_axis=1)
|
||||||
|
self.fc0 = paddle.nn.Linear(
|
||||||
|
in_features=in_channels,
|
||||||
|
out_features=in_channels, )
|
||||||
|
self.emb = paddle.nn.Embedding(
|
||||||
|
num_embeddings=self.max_length, embedding_dim=in_channels)
|
||||||
|
self.flatten1 = paddle.nn.Flatten(start_axis=0, stop_axis=2)
|
||||||
|
self.fc1 = paddle.nn.Linear(
|
||||||
|
in_features=in_channels, out_features=1, bias_attr=False)
|
||||||
|
|
||||||
|
def forward(self, inputs, encoder_word_pos, gsrm_word_pos):
|
||||||
|
b, c, h, w = inputs.shape
|
||||||
|
conv_features = paddle.reshape(inputs, shape=[-1, c, h * w])
|
||||||
|
conv_features = paddle.transpose(conv_features, perm=[0, 2, 1])
|
||||||
|
# transformer encoder
|
||||||
|
b, t, c = conv_features.shape
|
||||||
|
|
||||||
|
enc_inputs = [conv_features, encoder_word_pos, None]
|
||||||
|
word_features = self.wrap_encoder_for_feature(enc_inputs)
|
||||||
|
|
||||||
|
# pvam
|
||||||
|
b, t, c = word_features.shape
|
||||||
|
word_features = self.fc0(word_features)
|
||||||
|
word_features_ = paddle.reshape(word_features, [-1, 1, t, c])
|
||||||
|
word_features_ = paddle.tile(word_features_, [1, self.max_length, 1, 1])
|
||||||
|
word_pos_feature = self.emb(gsrm_word_pos)
|
||||||
|
word_pos_feature_ = paddle.reshape(word_pos_feature,
|
||||||
|
[-1, self.max_length, 1, c])
|
||||||
|
word_pos_feature_ = paddle.tile(word_pos_feature_, [1, 1, t, 1])
|
||||||
|
y = word_pos_feature_ + word_features_
|
||||||
|
y = F.tanh(y)
|
||||||
|
attention_weight = self.fc1(y)
|
||||||
|
attention_weight = paddle.reshape(
|
||||||
|
attention_weight, shape=[-1, self.max_length, t])
|
||||||
|
attention_weight = F.softmax(attention_weight, axis=-1)
|
||||||
|
pvam_features = paddle.matmul(attention_weight,
|
||||||
|
word_features) #[b, max_length, c]
|
||||||
|
return pvam_features
|
||||||
|
|
||||||
|
|
||||||
|
class GSRM(nn.Layer):
|
||||||
|
def __init__(self, in_channels, char_num, max_text_length, num_heads,
|
||||||
|
num_encoder_tus, num_decoder_tus, hidden_dims):
|
||||||
|
super(GSRM, self).__init__()
|
||||||
|
self.char_num = char_num
|
||||||
|
self.max_length = max_text_length
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_encoder_TUs = num_encoder_tus
|
||||||
|
self.num_decoder_TUs = num_decoder_tus
|
||||||
|
self.hidden_dims = hidden_dims
|
||||||
|
|
||||||
|
self.fc0 = paddle.nn.Linear(
|
||||||
|
in_features=in_channels, out_features=self.char_num)
|
||||||
|
self.wrap_encoder0 = WrapEncoder(
|
||||||
|
src_vocab_size=self.char_num + 1,
|
||||||
|
max_length=self.max_length,
|
||||||
|
n_layer=self.num_decoder_TUs,
|
||||||
|
n_head=self.num_heads,
|
||||||
|
d_key=int(self.hidden_dims / self.num_heads),
|
||||||
|
d_value=int(self.hidden_dims / self.num_heads),
|
||||||
|
d_model=self.hidden_dims,
|
||||||
|
d_inner_hid=self.hidden_dims,
|
||||||
|
prepostprocess_dropout=0.1,
|
||||||
|
attention_dropout=0.1,
|
||||||
|
relu_dropout=0.1,
|
||||||
|
preprocess_cmd="n",
|
||||||
|
postprocess_cmd="da",
|
||||||
|
weight_sharing=True)
|
||||||
|
|
||||||
|
self.wrap_encoder1 = WrapEncoder(
|
||||||
|
src_vocab_size=self.char_num + 1,
|
||||||
|
max_length=self.max_length,
|
||||||
|
n_layer=self.num_decoder_TUs,
|
||||||
|
n_head=self.num_heads,
|
||||||
|
d_key=int(self.hidden_dims / self.num_heads),
|
||||||
|
d_value=int(self.hidden_dims / self.num_heads),
|
||||||
|
d_model=self.hidden_dims,
|
||||||
|
d_inner_hid=self.hidden_dims,
|
||||||
|
prepostprocess_dropout=0.1,
|
||||||
|
attention_dropout=0.1,
|
||||||
|
relu_dropout=0.1,
|
||||||
|
preprocess_cmd="n",
|
||||||
|
postprocess_cmd="da",
|
||||||
|
weight_sharing=True)
|
||||||
|
|
||||||
|
self.mul = lambda x: paddle.matmul(x=x,
|
||||||
|
y=self.wrap_encoder0.prepare_decoder.emb0.weight,
|
||||||
|
transpose_y=True)
|
||||||
|
|
||||||
|
def forward(self, inputs, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||||
|
gsrm_slf_attn_bias2):
|
||||||
|
# ===== GSRM Visual-to-semantic embedding block =====
|
||||||
|
b, t, c = inputs.shape
|
||||||
|
pvam_features = paddle.reshape(inputs, [-1, c])
|
||||||
|
word_out = self.fc0(pvam_features)
|
||||||
|
word_ids = paddle.argmax(F.softmax(word_out), axis=1)
|
||||||
|
word_ids = paddle.reshape(x=word_ids, shape=[-1, t, 1])
|
||||||
|
|
||||||
|
#===== GSRM Semantic reasoning block =====
|
||||||
|
"""
|
||||||
|
This module is achieved through bi-transformers,
|
||||||
|
ngram_feature1 is the froward one, ngram_fetaure2 is the backward one
|
||||||
|
"""
|
||||||
|
pad_idx = self.char_num
|
||||||
|
|
||||||
|
word1 = paddle.cast(word_ids, "float32")
|
||||||
|
word1 = F.pad(word1, [1, 0], value=1.0 * pad_idx, data_format="NLC")
|
||||||
|
word1 = paddle.cast(word1, "int64")
|
||||||
|
word1 = word1[:, :-1, :]
|
||||||
|
word2 = word_ids
|
||||||
|
|
||||||
|
enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1]
|
||||||
|
enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2]
|
||||||
|
|
||||||
|
gsrm_feature1 = self.wrap_encoder0(enc_inputs_1)
|
||||||
|
gsrm_feature2 = self.wrap_encoder1(enc_inputs_2)
|
||||||
|
|
||||||
|
gsrm_feature2 = F.pad(gsrm_feature2, [0, 1],
|
||||||
|
value=0.,
|
||||||
|
data_format="NLC")
|
||||||
|
gsrm_feature2 = gsrm_feature2[:, 1:, ]
|
||||||
|
gsrm_features = gsrm_feature1 + gsrm_feature2
|
||||||
|
|
||||||
|
gsrm_out = self.mul(gsrm_features)
|
||||||
|
|
||||||
|
b, t, c = gsrm_out.shape
|
||||||
|
gsrm_out = paddle.reshape(gsrm_out, [-1, c])
|
||||||
|
|
||||||
|
return gsrm_features, word_out, gsrm_out
|
||||||
|
|
||||||
|
|
||||||
|
class VSFD(nn.Layer):
|
||||||
|
def __init__(self, in_channels=512, pvam_ch=512, char_num=38):
|
||||||
|
super(VSFD, self).__init__()
|
||||||
|
self.char_num = char_num
|
||||||
|
self.fc0 = paddle.nn.Linear(
|
||||||
|
in_features=in_channels * 2, out_features=pvam_ch)
|
||||||
|
self.fc1 = paddle.nn.Linear(
|
||||||
|
in_features=pvam_ch, out_features=self.char_num)
|
||||||
|
|
||||||
|
def forward(self, pvam_feature, gsrm_feature):
|
||||||
|
b, t, c1 = pvam_feature.shape
|
||||||
|
b, t, c2 = gsrm_feature.shape
|
||||||
|
combine_feature_ = paddle.concat([pvam_feature, gsrm_feature], axis=2)
|
||||||
|
img_comb_feature_ = paddle.reshape(
|
||||||
|
combine_feature_, shape=[-1, c1 + c2])
|
||||||
|
img_comb_feature_map = self.fc0(img_comb_feature_)
|
||||||
|
img_comb_feature_map = F.sigmoid(img_comb_feature_map)
|
||||||
|
img_comb_feature_map = paddle.reshape(
|
||||||
|
img_comb_feature_map, shape=[-1, t, c1])
|
||||||
|
combine_feature = img_comb_feature_map * pvam_feature + (
|
||||||
|
1.0 - img_comb_feature_map) * gsrm_feature
|
||||||
|
img_comb_feature = paddle.reshape(combine_feature, shape=[-1, c1])
|
||||||
|
|
||||||
|
out = self.fc1(img_comb_feature)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class SRNHead(nn.Layer):
|
||||||
|
def __init__(self, in_channels, out_channels, max_text_length, num_heads,
|
||||||
|
num_encoder_TUs, num_decoder_TUs, hidden_dims, **kwargs):
|
||||||
|
super(SRNHead, self).__init__()
|
||||||
|
self.char_num = out_channels
|
||||||
|
self.max_length = max_text_length
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.num_encoder_TUs = num_encoder_TUs
|
||||||
|
self.num_decoder_TUs = num_decoder_TUs
|
||||||
|
self.hidden_dims = hidden_dims
|
||||||
|
|
||||||
|
self.pvam = PVAM(
|
||||||
|
in_channels=in_channels,
|
||||||
|
char_num=self.char_num,
|
||||||
|
max_text_length=self.max_length,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
num_encoder_tus=self.num_encoder_TUs,
|
||||||
|
hidden_dims=self.hidden_dims)
|
||||||
|
|
||||||
|
self.gsrm = GSRM(
|
||||||
|
in_channels=in_channels,
|
||||||
|
char_num=self.char_num,
|
||||||
|
max_text_length=self.max_length,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
num_encoder_tus=self.num_encoder_TUs,
|
||||||
|
num_decoder_tus=self.num_decoder_TUs,
|
||||||
|
hidden_dims=self.hidden_dims)
|
||||||
|
self.vsfd = VSFD(in_channels=in_channels)
|
||||||
|
|
||||||
|
self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
|
||||||
|
|
||||||
|
def forward(self, inputs, others):
|
||||||
|
encoder_word_pos = others[0]
|
||||||
|
gsrm_word_pos = others[1]
|
||||||
|
gsrm_slf_attn_bias1 = others[2]
|
||||||
|
gsrm_slf_attn_bias2 = others[3]
|
||||||
|
|
||||||
|
pvam_feature = self.pvam(inputs, encoder_word_pos, gsrm_word_pos)
|
||||||
|
|
||||||
|
gsrm_feature, word_out, gsrm_out = self.gsrm(
|
||||||
|
pvam_feature, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||||
|
gsrm_slf_attn_bias2)
|
||||||
|
|
||||||
|
final_out = self.vsfd(pvam_feature, gsrm_feature)
|
||||||
|
if not self.training:
|
||||||
|
final_out = F.softmax(final_out, axis=1)
|
||||||
|
|
||||||
|
_, decoded_out = paddle.topk(final_out, k=1)
|
||||||
|
|
||||||
|
predicts = OrderedDict([
|
||||||
|
('predict', final_out),
|
||||||
|
('pvam_feature', pvam_feature),
|
||||||
|
('decoded_out', decoded_out),
|
||||||
|
('word_out', word_out),
|
||||||
|
('gsrm_out', gsrm_out),
|
||||||
|
])
|
||||||
|
|
||||||
|
return predicts
|
|
@ -0,0 +1,409 @@
|
||||||
|
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import ParamAttr, nn
|
||||||
|
from paddle import nn, ParamAttr
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
import numpy as np
|
||||||
|
gradient_clip = 10
|
||||||
|
|
||||||
|
|
||||||
|
class WrapEncoderForFeature(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
src_vocab_size,
|
||||||
|
max_length,
|
||||||
|
n_layer,
|
||||||
|
n_head,
|
||||||
|
d_key,
|
||||||
|
d_value,
|
||||||
|
d_model,
|
||||||
|
d_inner_hid,
|
||||||
|
prepostprocess_dropout,
|
||||||
|
attention_dropout,
|
||||||
|
relu_dropout,
|
||||||
|
preprocess_cmd,
|
||||||
|
postprocess_cmd,
|
||||||
|
weight_sharing,
|
||||||
|
bos_idx=0):
|
||||||
|
super(WrapEncoderForFeature, self).__init__()
|
||||||
|
|
||||||
|
self.prepare_encoder = PrepareEncoder(
|
||||||
|
src_vocab_size,
|
||||||
|
d_model,
|
||||||
|
max_length,
|
||||||
|
prepostprocess_dropout,
|
||||||
|
bos_idx=bos_idx,
|
||||||
|
word_emb_param_name="src_word_emb_table")
|
||||||
|
self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
|
||||||
|
d_inner_hid, prepostprocess_dropout,
|
||||||
|
attention_dropout, relu_dropout, preprocess_cmd,
|
||||||
|
postprocess_cmd)
|
||||||
|
|
||||||
|
def forward(self, enc_inputs):
|
||||||
|
conv_features, src_pos, src_slf_attn_bias = enc_inputs
|
||||||
|
enc_input = self.prepare_encoder(conv_features, src_pos)
|
||||||
|
enc_output = self.encoder(enc_input, src_slf_attn_bias)
|
||||||
|
return enc_output
|
||||||
|
|
||||||
|
|
||||||
|
class WrapEncoder(nn.Layer):
|
||||||
|
"""
|
||||||
|
embedder + encoder
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
src_vocab_size,
|
||||||
|
max_length,
|
||||||
|
n_layer,
|
||||||
|
n_head,
|
||||||
|
d_key,
|
||||||
|
d_value,
|
||||||
|
d_model,
|
||||||
|
d_inner_hid,
|
||||||
|
prepostprocess_dropout,
|
||||||
|
attention_dropout,
|
||||||
|
relu_dropout,
|
||||||
|
preprocess_cmd,
|
||||||
|
postprocess_cmd,
|
||||||
|
weight_sharing,
|
||||||
|
bos_idx=0):
|
||||||
|
super(WrapEncoder, self).__init__()
|
||||||
|
|
||||||
|
self.prepare_decoder = PrepareDecoder(
|
||||||
|
src_vocab_size,
|
||||||
|
d_model,
|
||||||
|
max_length,
|
||||||
|
prepostprocess_dropout,
|
||||||
|
bos_idx=bos_idx)
|
||||||
|
self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
|
||||||
|
d_inner_hid, prepostprocess_dropout,
|
||||||
|
attention_dropout, relu_dropout, preprocess_cmd,
|
||||||
|
postprocess_cmd)
|
||||||
|
|
||||||
|
def forward(self, enc_inputs):
|
||||||
|
src_word, src_pos, src_slf_attn_bias = enc_inputs
|
||||||
|
enc_input = self.prepare_decoder(src_word, src_pos)
|
||||||
|
enc_output = self.encoder(enc_input, src_slf_attn_bias)
|
||||||
|
return enc_output
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder(nn.Layer):
|
||||||
|
"""
|
||||||
|
encoder
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
n_layer,
|
||||||
|
n_head,
|
||||||
|
d_key,
|
||||||
|
d_value,
|
||||||
|
d_model,
|
||||||
|
d_inner_hid,
|
||||||
|
prepostprocess_dropout,
|
||||||
|
attention_dropout,
|
||||||
|
relu_dropout,
|
||||||
|
preprocess_cmd="n",
|
||||||
|
postprocess_cmd="da"):
|
||||||
|
|
||||||
|
super(Encoder, self).__init__()
|
||||||
|
|
||||||
|
self.encoder_layers = list()
|
||||||
|
for i in range(n_layer):
|
||||||
|
self.encoder_layers.append(
|
||||||
|
self.add_sublayer(
|
||||||
|
"layer_%d" % i,
|
||||||
|
EncoderLayer(n_head, d_key, d_value, d_model, d_inner_hid,
|
||||||
|
prepostprocess_dropout, attention_dropout,
|
||||||
|
relu_dropout, preprocess_cmd,
|
||||||
|
postprocess_cmd)))
|
||||||
|
self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
|
||||||
|
prepostprocess_dropout)
|
||||||
|
|
||||||
|
def forward(self, enc_input, attn_bias):
|
||||||
|
for encoder_layer in self.encoder_layers:
|
||||||
|
enc_output = encoder_layer(enc_input, attn_bias)
|
||||||
|
enc_input = enc_output
|
||||||
|
enc_output = self.processer(enc_output)
|
||||||
|
return enc_output
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderLayer(nn.Layer):
|
||||||
|
"""
|
||||||
|
EncoderLayer
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
n_head,
|
||||||
|
d_key,
|
||||||
|
d_value,
|
||||||
|
d_model,
|
||||||
|
d_inner_hid,
|
||||||
|
prepostprocess_dropout,
|
||||||
|
attention_dropout,
|
||||||
|
relu_dropout,
|
||||||
|
preprocess_cmd="n",
|
||||||
|
postprocess_cmd="da"):
|
||||||
|
|
||||||
|
super(EncoderLayer, self).__init__()
|
||||||
|
self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
|
||||||
|
prepostprocess_dropout)
|
||||||
|
self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
|
||||||
|
attention_dropout)
|
||||||
|
self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
|
||||||
|
prepostprocess_dropout)
|
||||||
|
|
||||||
|
self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
|
||||||
|
prepostprocess_dropout)
|
||||||
|
self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
|
||||||
|
self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
|
||||||
|
prepostprocess_dropout)
|
||||||
|
|
||||||
|
def forward(self, enc_input, attn_bias):
|
||||||
|
attn_output = self.self_attn(
|
||||||
|
self.preprocesser1(enc_input), None, None, attn_bias)
|
||||||
|
attn_output = self.postprocesser1(attn_output, enc_input)
|
||||||
|
ffn_output = self.ffn(self.preprocesser2(attn_output))
|
||||||
|
ffn_output = self.postprocesser2(ffn_output, attn_output)
|
||||||
|
return ffn_output
|
||||||
|
|
||||||
|
|
||||||
|
class MultiHeadAttention(nn.Layer):
|
||||||
|
"""
|
||||||
|
Multi-Head Attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
|
||||||
|
super(MultiHeadAttention, self).__init__()
|
||||||
|
self.n_head = n_head
|
||||||
|
self.d_key = d_key
|
||||||
|
self.d_value = d_value
|
||||||
|
self.d_model = d_model
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
self.q_fc = paddle.nn.Linear(
|
||||||
|
in_features=d_model, out_features=d_key * n_head, bias_attr=False)
|
||||||
|
self.k_fc = paddle.nn.Linear(
|
||||||
|
in_features=d_model, out_features=d_key * n_head, bias_attr=False)
|
||||||
|
self.v_fc = paddle.nn.Linear(
|
||||||
|
in_features=d_model, out_features=d_value * n_head, bias_attr=False)
|
||||||
|
self.proj_fc = paddle.nn.Linear(
|
||||||
|
in_features=d_value * n_head, out_features=d_model, bias_attr=False)
|
||||||
|
|
||||||
|
def _prepare_qkv(self, queries, keys, values, cache=None):
|
||||||
|
if keys is None: # self-attention
|
||||||
|
keys, values = queries, queries
|
||||||
|
static_kv = False
|
||||||
|
else: # cross-attention
|
||||||
|
static_kv = True
|
||||||
|
|
||||||
|
q = self.q_fc(queries)
|
||||||
|
q = paddle.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
|
||||||
|
q = paddle.transpose(x=q, perm=[0, 2, 1, 3])
|
||||||
|
|
||||||
|
if cache is not None and static_kv and "static_k" in cache:
|
||||||
|
# for encoder-decoder attention in inference and has cached
|
||||||
|
k = cache["static_k"]
|
||||||
|
v = cache["static_v"]
|
||||||
|
else:
|
||||||
|
k = self.k_fc(keys)
|
||||||
|
v = self.v_fc(values)
|
||||||
|
k = paddle.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
|
||||||
|
k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
|
||||||
|
v = paddle.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
|
||||||
|
v = paddle.transpose(x=v, perm=[0, 2, 1, 3])
|
||||||
|
|
||||||
|
if cache is not None:
|
||||||
|
if static_kv and not "static_k" in cache:
|
||||||
|
# for encoder-decoder attention in inference and has not cached
|
||||||
|
cache["static_k"], cache["static_v"] = k, v
|
||||||
|
elif not static_kv:
|
||||||
|
# for decoder self-attention in inference
|
||||||
|
cache_k, cache_v = cache["k"], cache["v"]
|
||||||
|
k = paddle.concat([cache_k, k], axis=2)
|
||||||
|
v = paddle.concat([cache_v, v], axis=2)
|
||||||
|
cache["k"], cache["v"] = k, v
|
||||||
|
|
||||||
|
return q, k, v
|
||||||
|
|
||||||
|
def forward(self, queries, keys, values, attn_bias, cache=None):
|
||||||
|
# compute q ,k ,v
|
||||||
|
keys = queries if keys is None else keys
|
||||||
|
values = keys if values is None else values
|
||||||
|
q, k, v = self._prepare_qkv(queries, keys, values, cache)
|
||||||
|
|
||||||
|
# scale dot product attention
|
||||||
|
product = paddle.matmul(x=q, y=k, transpose_y=True)
|
||||||
|
product = product * self.d_model**-0.5
|
||||||
|
if attn_bias is not None:
|
||||||
|
product += attn_bias
|
||||||
|
weights = F.softmax(product)
|
||||||
|
if self.dropout_rate:
|
||||||
|
weights = F.dropout(
|
||||||
|
weights, p=self.dropout_rate, mode="downscale_in_infer")
|
||||||
|
out = paddle.matmul(weights, v)
|
||||||
|
|
||||||
|
# combine heads
|
||||||
|
out = paddle.transpose(out, perm=[0, 2, 1, 3])
|
||||||
|
out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
|
||||||
|
|
||||||
|
# project to output
|
||||||
|
out = self.proj_fc(out)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class PrePostProcessLayer(nn.Layer):
|
||||||
|
"""
|
||||||
|
PrePostProcessLayer
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, process_cmd, d_model, dropout_rate):
|
||||||
|
super(PrePostProcessLayer, self).__init__()
|
||||||
|
self.process_cmd = process_cmd
|
||||||
|
self.functors = []
|
||||||
|
for cmd in self.process_cmd:
|
||||||
|
if cmd == "a": # add residual connection
|
||||||
|
self.functors.append(lambda x, y: x + y if y is not None else x)
|
||||||
|
elif cmd == "n": # add layer normalization
|
||||||
|
self.functors.append(
|
||||||
|
self.add_sublayer(
|
||||||
|
"layer_norm_%d" % len(
|
||||||
|
self.sublayers(include_sublayers=False)),
|
||||||
|
paddle.nn.LayerNorm(
|
||||||
|
normalized_shape=d_model,
|
||||||
|
weight_attr=fluid.ParamAttr(
|
||||||
|
initializer=fluid.initializer.Constant(1.)),
|
||||||
|
bias_attr=fluid.ParamAttr(
|
||||||
|
initializer=fluid.initializer.Constant(0.)))))
|
||||||
|
elif cmd == "d": # add dropout
|
||||||
|
self.functors.append(lambda x: F.dropout(
|
||||||
|
x, p=dropout_rate, mode="downscale_in_infer")
|
||||||
|
if dropout_rate else x)
|
||||||
|
|
||||||
|
def forward(self, x, residual=None):
|
||||||
|
for i, cmd in enumerate(self.process_cmd):
|
||||||
|
if cmd == "a":
|
||||||
|
x = self.functors[i](x, residual)
|
||||||
|
else:
|
||||||
|
x = self.functors[i](x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PrepareEncoder(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
src_vocab_size,
|
||||||
|
src_emb_dim,
|
||||||
|
src_max_len,
|
||||||
|
dropout_rate=0,
|
||||||
|
bos_idx=0,
|
||||||
|
word_emb_param_name=None,
|
||||||
|
pos_enc_param_name=None):
|
||||||
|
super(PrepareEncoder, self).__init__()
|
||||||
|
self.src_emb_dim = src_emb_dim
|
||||||
|
self.src_max_len = src_max_len
|
||||||
|
self.emb = paddle.nn.Embedding(
|
||||||
|
num_embeddings=self.src_max_len,
|
||||||
|
embedding_dim=self.src_emb_dim,
|
||||||
|
sparse=True)
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
|
||||||
|
def forward(self, src_word, src_pos):
|
||||||
|
src_word_emb = src_word
|
||||||
|
src_word_emb = fluid.layers.cast(src_word_emb, 'float32')
|
||||||
|
src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
|
||||||
|
src_pos = paddle.squeeze(src_pos, axis=-1)
|
||||||
|
src_pos_enc = self.emb(src_pos)
|
||||||
|
src_pos_enc.stop_gradient = True
|
||||||
|
enc_input = src_word_emb + src_pos_enc
|
||||||
|
if self.dropout_rate:
|
||||||
|
out = F.dropout(
|
||||||
|
x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
|
||||||
|
else:
|
||||||
|
out = enc_input
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class PrepareDecoder(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
src_vocab_size,
|
||||||
|
src_emb_dim,
|
||||||
|
src_max_len,
|
||||||
|
dropout_rate=0,
|
||||||
|
bos_idx=0,
|
||||||
|
word_emb_param_name=None,
|
||||||
|
pos_enc_param_name=None):
|
||||||
|
super(PrepareDecoder, self).__init__()
|
||||||
|
self.src_emb_dim = src_emb_dim
|
||||||
|
"""
|
||||||
|
self.emb0 = Embedding(num_embeddings=src_vocab_size,
|
||||||
|
embedding_dim=src_emb_dim)
|
||||||
|
"""
|
||||||
|
self.emb0 = paddle.nn.Embedding(
|
||||||
|
num_embeddings=src_vocab_size,
|
||||||
|
embedding_dim=self.src_emb_dim,
|
||||||
|
padding_idx=bos_idx,
|
||||||
|
weight_attr=paddle.ParamAttr(
|
||||||
|
name=word_emb_param_name,
|
||||||
|
initializer=nn.initializer.Normal(0., src_emb_dim**-0.5)))
|
||||||
|
self.emb1 = paddle.nn.Embedding(
|
||||||
|
num_embeddings=src_max_len,
|
||||||
|
embedding_dim=self.src_emb_dim,
|
||||||
|
weight_attr=paddle.ParamAttr(name=pos_enc_param_name))
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
|
||||||
|
def forward(self, src_word, src_pos):
|
||||||
|
src_word = fluid.layers.cast(src_word, 'int64')
|
||||||
|
src_word = paddle.squeeze(src_word, axis=-1)
|
||||||
|
src_word_emb = self.emb0(src_word)
|
||||||
|
src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
|
||||||
|
src_pos = paddle.squeeze(src_pos, axis=-1)
|
||||||
|
src_pos_enc = self.emb1(src_pos)
|
||||||
|
src_pos_enc.stop_gradient = True
|
||||||
|
enc_input = src_word_emb + src_pos_enc
|
||||||
|
if self.dropout_rate:
|
||||||
|
out = F.dropout(
|
||||||
|
x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
|
||||||
|
else:
|
||||||
|
out = enc_input
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
class FFN(nn.Layer):
|
||||||
|
"""
|
||||||
|
Feed-Forward Network
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, d_inner_hid, d_model, dropout_rate):
|
||||||
|
super(FFN, self).__init__()
|
||||||
|
self.dropout_rate = dropout_rate
|
||||||
|
self.fc1 = paddle.nn.Linear(
|
||||||
|
in_features=d_model, out_features=d_inner_hid)
|
||||||
|
self.fc2 = paddle.nn.Linear(
|
||||||
|
in_features=d_inner_hid, out_features=d_model)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
hidden = self.fc1(x)
|
||||||
|
hidden = F.relu(hidden)
|
||||||
|
if self.dropout_rate:
|
||||||
|
hidden = F.dropout(
|
||||||
|
hidden, p=self.dropout_rate, mode="downscale_in_infer")
|
||||||
|
out = self.fc2(hidden)
|
||||||
|
return out
|
|
@ -26,11 +26,12 @@ def build_post_process(config, global_config=None):
|
||||||
from .db_postprocess import DBPostProcess
|
from .db_postprocess import DBPostProcess
|
||||||
from .east_postprocess import EASTPostProcess
|
from .east_postprocess import EASTPostProcess
|
||||||
from .sast_postprocess import SASTPostProcess
|
from .sast_postprocess import SASTPostProcess
|
||||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode
|
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
|
||||||
from .cls_postprocess import ClsPostProcess
|
from .cls_postprocess import ClsPostProcess
|
||||||
|
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess'
|
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||||
|
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode'
|
||||||
]
|
]
|
||||||
|
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import string
|
||||||
import paddle
|
import paddle
|
||||||
from paddle.nn import functional as F
|
from paddle.nn import functional as F
|
||||||
|
|
||||||
|
@ -24,19 +25,28 @@ class BaseRecLabelDecode(object):
|
||||||
character_type='ch',
|
character_type='ch',
|
||||||
use_space_char=False):
|
use_space_char=False):
|
||||||
support_character_type = [
|
support_character_type = [
|
||||||
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean', 'it',
|
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
|
||||||
'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc', 'rsc', 'bg',
|
'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
|
||||||
'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr', 'ne'
|
'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
|
||||||
|
'ne', 'EN'
|
||||||
]
|
]
|
||||||
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
||||||
support_character_type, character_type)
|
support_character_type, character_type)
|
||||||
|
|
||||||
|
self.beg_str = "sos"
|
||||||
|
self.end_str = "eos"
|
||||||
|
|
||||||
if character_type == "en":
|
if character_type == "en":
|
||||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||||
dict_character = list(self.character_str)
|
dict_character = list(self.character_str)
|
||||||
elif character_type in ["ch", "french", "german", "japan", "korean"]:
|
elif character_type == "EN_symbol":
|
||||||
|
# same with ASTER setting (use 94 char).
|
||||||
|
self.character_str = string.printable[:-6]
|
||||||
|
dict_character = list(self.character_str)
|
||||||
|
elif character_type in support_character_type:
|
||||||
self.character_str = ""
|
self.character_str = ""
|
||||||
assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch"
|
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
|
||||||
|
character_type)
|
||||||
with open(character_dict_path, "rb") as fin:
|
with open(character_dict_path, "rb") as fin:
|
||||||
lines = fin.readlines()
|
lines = fin.readlines()
|
||||||
for line in lines:
|
for line in lines:
|
||||||
|
@ -45,11 +55,7 @@ class BaseRecLabelDecode(object):
|
||||||
if use_space_char:
|
if use_space_char:
|
||||||
self.character_str += " "
|
self.character_str += " "
|
||||||
dict_character = list(self.character_str)
|
dict_character = list(self.character_str)
|
||||||
elif character_type == "en_sensitive":
|
|
||||||
# same with ASTER setting (use 94 char).
|
|
||||||
import string
|
|
||||||
self.character_str = string.printable[:-6]
|
|
||||||
dict_character = list(self.character_str)
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
self.character_type = character_type
|
self.character_type = character_type
|
||||||
|
@ -106,7 +112,6 @@ class CTCLabelDecode(BaseRecLabelDecode):
|
||||||
def __call__(self, preds, label=None, *args, **kwargs):
|
def __call__(self, preds, label=None, *args, **kwargs):
|
||||||
if isinstance(preds, paddle.Tensor):
|
if isinstance(preds, paddle.Tensor):
|
||||||
preds = preds.numpy()
|
preds = preds.numpy()
|
||||||
|
|
||||||
preds_idx = preds.argmax(axis=2)
|
preds_idx = preds.argmax(axis=2)
|
||||||
preds_prob = preds.max(axis=2)
|
preds_prob = preds.max(axis=2)
|
||||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
||||||
|
@ -155,3 +160,84 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
||||||
assert False, "unsupport type %s in get_beg_end_flag_idx" \
|
assert False, "unsupport type %s in get_beg_end_flag_idx" \
|
||||||
% beg_or_end
|
% beg_or_end
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
|
|
||||||
|
class SRNLabelDecode(BaseRecLabelDecode):
|
||||||
|
""" Convert between text-label and text-index """
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
character_dict_path=None,
|
||||||
|
character_type='en',
|
||||||
|
use_space_char=False,
|
||||||
|
**kwargs):
|
||||||
|
super(SRNLabelDecode, self).__init__(character_dict_path,
|
||||||
|
character_type, use_space_char)
|
||||||
|
|
||||||
|
def __call__(self, preds, label=None, *args, **kwargs):
|
||||||
|
pred = preds['predict']
|
||||||
|
char_num = len(self.character_str) + 2
|
||||||
|
if isinstance(pred, paddle.Tensor):
|
||||||
|
pred = pred.numpy()
|
||||||
|
pred = np.reshape(pred, [-1, char_num])
|
||||||
|
|
||||||
|
preds_idx = np.argmax(pred, axis=1)
|
||||||
|
preds_prob = np.max(pred, axis=1)
|
||||||
|
|
||||||
|
preds_idx = np.reshape(preds_idx, [-1, 25])
|
||||||
|
|
||||||
|
preds_prob = np.reshape(preds_prob, [-1, 25])
|
||||||
|
|
||||||
|
text = self.decode(preds_idx, preds_prob)
|
||||||
|
|
||||||
|
if label is None:
|
||||||
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||||
|
return text
|
||||||
|
label = self.decode(label)
|
||||||
|
return text, label
|
||||||
|
|
||||||
|
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||||
|
""" convert text-index into text-label. """
|
||||||
|
result_list = []
|
||||||
|
ignored_tokens = self.get_ignored_tokens()
|
||||||
|
batch_size = len(text_index)
|
||||||
|
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
char_list = []
|
||||||
|
conf_list = []
|
||||||
|
for idx in range(len(text_index[batch_idx])):
|
||||||
|
if text_index[batch_idx][idx] in ignored_tokens:
|
||||||
|
continue
|
||||||
|
if is_remove_duplicate:
|
||||||
|
# only for predict
|
||||||
|
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||||
|
batch_idx][idx]:
|
||||||
|
continue
|
||||||
|
char_list.append(self.character[int(text_index[batch_idx][
|
||||||
|
idx])])
|
||||||
|
if text_prob is not None:
|
||||||
|
conf_list.append(text_prob[batch_idx][idx])
|
||||||
|
else:
|
||||||
|
conf_list.append(1)
|
||||||
|
|
||||||
|
text = ''.join(char_list)
|
||||||
|
result_list.append((text, np.mean(conf_list)))
|
||||||
|
return result_list
|
||||||
|
|
||||||
|
def add_special_char(self, dict_character):
|
||||||
|
dict_character = dict_character + [self.beg_str, self.end_str]
|
||||||
|
return dict_character
|
||||||
|
|
||||||
|
def get_ignored_tokens(self):
|
||||||
|
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||||
|
end_idx = self.get_beg_end_flag_idx("end")
|
||||||
|
return [beg_idx, end_idx]
|
||||||
|
|
||||||
|
def get_beg_end_flag_idx(self, beg_or_end):
|
||||||
|
if beg_or_end == "beg":
|
||||||
|
idx = np.array(self.dict[self.beg_str])
|
||||||
|
elif beg_or_end == "end":
|
||||||
|
idx = np.array(self.dict[self.end_str])
|
||||||
|
else:
|
||||||
|
assert False, "unsupport type %s in get_beg_end_flag_idx" \
|
||||||
|
% beg_or_end
|
||||||
|
return idx
|
||||||
|
|
|
@ -31,6 +31,14 @@ from ppocr.utils.logging import get_logger
|
||||||
from tools.program import load_config, merge_config, ArgsParser
|
from tools.program import load_config, merge_config, ArgsParser
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("-c", "--config", help="configuration file to use")
|
||||||
|
parser.add_argument(
|
||||||
|
"-o", "--output_path", type=str, default='./output/infer/')
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
FLAGS = ArgsParser().parse_args()
|
FLAGS = ArgsParser().parse_args()
|
||||||
config = load_config(FLAGS.config)
|
config = load_config(FLAGS.config)
|
||||||
|
@ -51,14 +59,40 @@ def main():
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
|
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
|
||||||
infer_shape = [3, 32, 100] if config['Architecture'][
|
|
||||||
'model_type'] != "det" else [3, 640, 640]
|
if config['Architecture']['algorithm'] == "SRN":
|
||||||
model = to_static(
|
other_shape = [
|
||||||
model,
|
|
||||||
input_spec=[
|
|
||||||
paddle.static.InputSpec(
|
paddle.static.InputSpec(
|
||||||
shape=[None] + infer_shape, dtype='float32')
|
shape=[None, 1, 64, 256], dtype='float32'), [
|
||||||
])
|
paddle.static.InputSpec(
|
||||||
|
shape=[None, 256, 1],
|
||||||
|
dtype="int64"), paddle.static.InputSpec(
|
||||||
|
shape=[None, 25, 1],
|
||||||
|
dtype="int64"), paddle.static.InputSpec(
|
||||||
|
shape=[None, 8, 25, 25], dtype="int64"),
|
||||||
|
paddle.static.InputSpec(
|
||||||
|
shape=[None, 8, 25, 25], dtype="int64")
|
||||||
|
]
|
||||||
|
]
|
||||||
|
model = to_static(model, input_spec=other_shape)
|
||||||
|
else:
|
||||||
|
infer_shape = [3, -1, -1]
|
||||||
|
if config['Architecture']['model_type'] == "rec":
|
||||||
|
infer_shape = [3, 32, -1] # for rec model, H must be 32
|
||||||
|
if 'Transform' in config['Architecture'] and config['Architecture'][
|
||||||
|
'Transform'] is not None and config['Architecture'][
|
||||||
|
'Transform']['name'] == 'TPS':
|
||||||
|
logger.info(
|
||||||
|
'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
|
||||||
|
)
|
||||||
|
infer_shape[-1] = 100
|
||||||
|
model = to_static(
|
||||||
|
model,
|
||||||
|
input_spec=[
|
||||||
|
paddle.static.InputSpec(
|
||||||
|
shape=[None] + infer_shape, dtype='float32')
|
||||||
|
])
|
||||||
|
|
||||||
paddle.jit.save(model, save_path)
|
paddle.jit.save(model, save_path)
|
||||||
logger.info('inference model is saved to {}'.format(save_path))
|
logger.info('inference model is saved to {}'.format(save_path))
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,7 @@ import numpy as np
|
||||||
import math
|
import math
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
import paddle
|
||||||
|
|
||||||
import tools.infer.utility as utility
|
import tools.infer.utility as utility
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
|
@ -46,6 +47,13 @@ class TextRecognizer(object):
|
||||||
"character_dict_path": args.rec_char_dict_path,
|
"character_dict_path": args.rec_char_dict_path,
|
||||||
"use_space_char": args.use_space_char
|
"use_space_char": args.use_space_char
|
||||||
}
|
}
|
||||||
|
if self.rec_algorithm == "SRN":
|
||||||
|
postprocess_params = {
|
||||||
|
'name': 'SRNLabelDecode',
|
||||||
|
"character_type": args.rec_char_type,
|
||||||
|
"character_dict_path": args.rec_char_dict_path,
|
||||||
|
"use_space_char": args.use_space_char
|
||||||
|
}
|
||||||
self.postprocess_op = build_post_process(postprocess_params)
|
self.postprocess_op = build_post_process(postprocess_params)
|
||||||
self.predictor, self.input_tensor, self.output_tensors = \
|
self.predictor, self.input_tensor, self.output_tensors = \
|
||||||
utility.create_predictor(args, 'rec', logger)
|
utility.create_predictor(args, 'rec', logger)
|
||||||
|
@ -70,6 +78,78 @@ class TextRecognizer(object):
|
||||||
padding_im[:, :, 0:resized_w] = resized_image
|
padding_im[:, :, 0:resized_w] = resized_image
|
||||||
return padding_im
|
return padding_im
|
||||||
|
|
||||||
|
def resize_norm_img_srn(self, img, image_shape):
|
||||||
|
imgC, imgH, imgW = image_shape
|
||||||
|
|
||||||
|
img_black = np.zeros((imgH, imgW))
|
||||||
|
im_hei = img.shape[0]
|
||||||
|
im_wid = img.shape[1]
|
||||||
|
|
||||||
|
if im_wid <= im_hei * 1:
|
||||||
|
img_new = cv2.resize(img, (imgH * 1, imgH))
|
||||||
|
elif im_wid <= im_hei * 2:
|
||||||
|
img_new = cv2.resize(img, (imgH * 2, imgH))
|
||||||
|
elif im_wid <= im_hei * 3:
|
||||||
|
img_new = cv2.resize(img, (imgH * 3, imgH))
|
||||||
|
else:
|
||||||
|
img_new = cv2.resize(img, (imgW, imgH))
|
||||||
|
|
||||||
|
img_np = np.asarray(img_new)
|
||||||
|
img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
|
||||||
|
img_black[:, 0:img_np.shape[1]] = img_np
|
||||||
|
img_black = img_black[:, :, np.newaxis]
|
||||||
|
|
||||||
|
row, col, c = img_black.shape
|
||||||
|
c = 1
|
||||||
|
|
||||||
|
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
||||||
|
|
||||||
|
def srn_other_inputs(self, image_shape, num_heads, max_text_length):
|
||||||
|
|
||||||
|
imgC, imgH, imgW = image_shape
|
||||||
|
feature_dim = int((imgH / 8) * (imgW / 8))
|
||||||
|
|
||||||
|
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
|
||||||
|
(feature_dim, 1)).astype('int64')
|
||||||
|
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
|
||||||
|
(max_text_length, 1)).astype('int64')
|
||||||
|
|
||||||
|
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||||
|
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
|
||||||
|
[-1, 1, max_text_length, max_text_length])
|
||||||
|
gsrm_slf_attn_bias1 = np.tile(
|
||||||
|
gsrm_slf_attn_bias1,
|
||||||
|
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
||||||
|
|
||||||
|
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
|
||||||
|
[-1, 1, max_text_length, max_text_length])
|
||||||
|
gsrm_slf_attn_bias2 = np.tile(
|
||||||
|
gsrm_slf_attn_bias2,
|
||||||
|
[1, num_heads, 1, 1]).astype('float32') * [-1e9]
|
||||||
|
|
||||||
|
encoder_word_pos = encoder_word_pos[np.newaxis, :]
|
||||||
|
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
|
||||||
|
|
||||||
|
return [
|
||||||
|
encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||||
|
gsrm_slf_attn_bias2
|
||||||
|
]
|
||||||
|
|
||||||
|
def process_image_srn(self, img, image_shape, num_heads, max_text_length):
|
||||||
|
norm_img = self.resize_norm_img_srn(img, image_shape)
|
||||||
|
norm_img = norm_img[np.newaxis, :]
|
||||||
|
|
||||||
|
[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
||||||
|
self.srn_other_inputs(image_shape, num_heads, max_text_length)
|
||||||
|
|
||||||
|
gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
|
||||||
|
gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
|
||||||
|
encoder_word_pos = encoder_word_pos.astype(np.int64)
|
||||||
|
gsrm_word_pos = gsrm_word_pos.astype(np.int64)
|
||||||
|
|
||||||
|
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||||
|
gsrm_slf_attn_bias2)
|
||||||
|
|
||||||
def __call__(self, img_list):
|
def __call__(self, img_list):
|
||||||
img_num = len(img_list)
|
img_num = len(img_list)
|
||||||
# Calculate the aspect ratio of all text bars
|
# Calculate the aspect ratio of all text bars
|
||||||
|
@ -93,21 +173,64 @@ class TextRecognizer(object):
|
||||||
wh_ratio = w * 1.0 / h
|
wh_ratio = w * 1.0 / h
|
||||||
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
||||||
for ino in range(beg_img_no, end_img_no):
|
for ino in range(beg_img_no, end_img_no):
|
||||||
# norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
|
if self.rec_algorithm != "SRN":
|
||||||
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
norm_img = self.resize_norm_img(img_list[indices[ino]],
|
||||||
max_wh_ratio)
|
max_wh_ratio)
|
||||||
norm_img = norm_img[np.newaxis, :]
|
norm_img = norm_img[np.newaxis, :]
|
||||||
norm_img_batch.append(norm_img)
|
norm_img_batch.append(norm_img)
|
||||||
|
else:
|
||||||
|
norm_img = self.process_image_srn(
|
||||||
|
img_list[indices[ino]], self.rec_image_shape, 8, 25)
|
||||||
|
encoder_word_pos_list = []
|
||||||
|
gsrm_word_pos_list = []
|
||||||
|
gsrm_slf_attn_bias1_list = []
|
||||||
|
gsrm_slf_attn_bias2_list = []
|
||||||
|
encoder_word_pos_list.append(norm_img[1])
|
||||||
|
gsrm_word_pos_list.append(norm_img[2])
|
||||||
|
gsrm_slf_attn_bias1_list.append(norm_img[3])
|
||||||
|
gsrm_slf_attn_bias2_list.append(norm_img[4])
|
||||||
|
norm_img_batch.append(norm_img[0])
|
||||||
norm_img_batch = np.concatenate(norm_img_batch)
|
norm_img_batch = np.concatenate(norm_img_batch)
|
||||||
norm_img_batch = norm_img_batch.copy()
|
norm_img_batch = norm_img_batch.copy()
|
||||||
starttime = time.time()
|
|
||||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
if self.rec_algorithm == "SRN":
|
||||||
self.predictor.run()
|
starttime = time.time()
|
||||||
outputs = []
|
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
|
||||||
for output_tensor in self.output_tensors:
|
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
|
||||||
output = output_tensor.copy_to_cpu()
|
gsrm_slf_attn_bias1_list = np.concatenate(
|
||||||
outputs.append(output)
|
gsrm_slf_attn_bias1_list)
|
||||||
preds = outputs[0]
|
gsrm_slf_attn_bias2_list = np.concatenate(
|
||||||
|
gsrm_slf_attn_bias2_list)
|
||||||
|
|
||||||
|
inputs = [
|
||||||
|
norm_img_batch,
|
||||||
|
encoder_word_pos_list,
|
||||||
|
gsrm_word_pos_list,
|
||||||
|
gsrm_slf_attn_bias1_list,
|
||||||
|
gsrm_slf_attn_bias2_list,
|
||||||
|
]
|
||||||
|
input_names = self.predictor.get_input_names()
|
||||||
|
for i in range(len(input_names)):
|
||||||
|
input_tensor = self.predictor.get_input_handle(input_names[
|
||||||
|
i])
|
||||||
|
input_tensor.copy_from_cpu(inputs[i])
|
||||||
|
self.predictor.run()
|
||||||
|
outputs = []
|
||||||
|
for output_tensor in self.output_tensors:
|
||||||
|
output = output_tensor.copy_to_cpu()
|
||||||
|
outputs.append(output)
|
||||||
|
preds = {"predict": outputs[2]}
|
||||||
|
else:
|
||||||
|
starttime = time.time()
|
||||||
|
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||||
|
self.predictor.run()
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
for output_tensor in self.output_tensors:
|
||||||
|
output = output_tensor.copy_to_cpu()
|
||||||
|
outputs.append(output)
|
||||||
|
preds = outputs[0]
|
||||||
|
|
||||||
rec_result = self.postprocess_op(preds)
|
rec_result = self.postprocess_op(preds)
|
||||||
for rno in range(len(rec_result)):
|
for rno in range(len(rec_result)):
|
||||||
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
|
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
|
||||||
|
|
|
@ -70,7 +70,7 @@ def parse_args():
|
||||||
default="./ppocr/utils/ppocr_keys_v1.txt")
|
default="./ppocr/utils/ppocr_keys_v1.txt")
|
||||||
parser.add_argument("--use_space_char", type=str2bool, default=True)
|
parser.add_argument("--use_space_char", type=str2bool, default=True)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vis_font_path", type=str, default="./doc/simfang.ttf")
|
"--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
|
||||||
parser.add_argument("--drop_score", type=float, default=0.5)
|
parser.add_argument("--drop_score", type=float, default=0.5)
|
||||||
|
|
||||||
# params for text classifier
|
# params for text classifier
|
||||||
|
|
|
@ -62,7 +62,13 @@ def main():
|
||||||
elif op_name in ['RecResizeImg']:
|
elif op_name in ['RecResizeImg']:
|
||||||
op[op_name]['infer_mode'] = True
|
op[op_name]['infer_mode'] = True
|
||||||
elif op_name == 'KeepKeys':
|
elif op_name == 'KeepKeys':
|
||||||
op[op_name]['keep_keys'] = ['image']
|
if config['Architecture']['algorithm'] == "SRN":
|
||||||
|
op[op_name]['keep_keys'] = [
|
||||||
|
'image', 'encoder_word_pos', 'gsrm_word_pos',
|
||||||
|
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
op[op_name]['keep_keys'] = ['image']
|
||||||
transforms.append(op)
|
transforms.append(op)
|
||||||
global_config['infer_mode'] = True
|
global_config['infer_mode'] = True
|
||||||
ops = create_operators(transforms, global_config)
|
ops = create_operators(transforms, global_config)
|
||||||
|
@ -74,10 +80,25 @@ def main():
|
||||||
img = f.read()
|
img = f.read()
|
||||||
data = {'image': img}
|
data = {'image': img}
|
||||||
batch = transform(data, ops)
|
batch = transform(data, ops)
|
||||||
|
if config['Architecture']['algorithm'] == "SRN":
|
||||||
|
encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
|
||||||
|
gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
|
||||||
|
gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
|
||||||
|
gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
|
||||||
|
|
||||||
|
others = [
|
||||||
|
paddle.to_tensor(encoder_word_pos_list),
|
||||||
|
paddle.to_tensor(gsrm_word_pos_list),
|
||||||
|
paddle.to_tensor(gsrm_slf_attn_bias1_list),
|
||||||
|
paddle.to_tensor(gsrm_slf_attn_bias2_list)
|
||||||
|
]
|
||||||
|
|
||||||
images = np.expand_dims(batch[0], axis=0)
|
images = np.expand_dims(batch[0], axis=0)
|
||||||
images = paddle.to_tensor(images)
|
images = paddle.to_tensor(images)
|
||||||
preds = model(images)
|
if config['Architecture']['algorithm'] == "SRN":
|
||||||
|
preds = model(images, others)
|
||||||
|
else:
|
||||||
|
preds = model(images)
|
||||||
post_result = post_process_class(preds)
|
post_result = post_process_class(preds)
|
||||||
for rec_reuslt in post_result:
|
for rec_reuslt in post_result:
|
||||||
logger.info('\t result: {}'.format(rec_reuslt))
|
logger.info('\t result: {}'.format(rec_reuslt))
|
||||||
|
|
|
@ -174,6 +174,7 @@ def train(config,
|
||||||
best_model_dict = {main_indicator: 0}
|
best_model_dict = {main_indicator: 0}
|
||||||
best_model_dict.update(pre_best_model_dict)
|
best_model_dict.update(pre_best_model_dict)
|
||||||
train_stats = TrainingStats(log_smooth_window, ['lr'])
|
train_stats = TrainingStats(log_smooth_window, ['lr'])
|
||||||
|
model_average = False
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
if 'start_epoch' in best_model_dict:
|
if 'start_epoch' in best_model_dict:
|
||||||
|
@ -182,8 +183,8 @@ def train(config,
|
||||||
start_epoch = 1
|
start_epoch = 1
|
||||||
|
|
||||||
for epoch in range(start_epoch, epoch_num + 1):
|
for epoch in range(start_epoch, epoch_num + 1):
|
||||||
if epoch > 0:
|
train_dataloader = build_dataloader(
|
||||||
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
config, 'Train', device, logger, seed=epoch)
|
||||||
train_batch_cost = 0.0
|
train_batch_cost = 0.0
|
||||||
train_reader_cost = 0.0
|
train_reader_cost = 0.0
|
||||||
batch_sum = 0
|
batch_sum = 0
|
||||||
|
@ -194,7 +195,12 @@ def train(config,
|
||||||
break
|
break
|
||||||
lr = optimizer.get_lr()
|
lr = optimizer.get_lr()
|
||||||
images = batch[0]
|
images = batch[0]
|
||||||
preds = model(images)
|
if config['Architecture']['algorithm'] == "SRN":
|
||||||
|
others = batch[-4:]
|
||||||
|
preds = model(images, others)
|
||||||
|
model_average = True
|
||||||
|
else:
|
||||||
|
preds = model(images)
|
||||||
loss = loss_class(preds, batch)
|
loss = loss_class(preds, batch)
|
||||||
avg_loss = loss['loss']
|
avg_loss = loss['loss']
|
||||||
avg_loss.backward()
|
avg_loss.backward()
|
||||||
|
@ -212,7 +218,7 @@ def train(config,
|
||||||
stats['lr'] = lr
|
stats['lr'] = lr
|
||||||
train_stats.update(stats)
|
train_stats.update(stats)
|
||||||
|
|
||||||
if cal_metric_during_train: # onlt rec and cls need
|
if cal_metric_during_train: # only rec and cls need
|
||||||
batch = [item.numpy() for item in batch]
|
batch = [item.numpy() for item in batch]
|
||||||
post_result = post_process_class(preds, batch[1])
|
post_result = post_process_class(preds, batch[1])
|
||||||
eval_class(post_result, batch)
|
eval_class(post_result, batch)
|
||||||
|
@ -238,21 +244,28 @@ def train(config,
|
||||||
# eval
|
# eval
|
||||||
if global_step > start_eval_step and \
|
if global_step > start_eval_step and \
|
||||||
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
|
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
|
||||||
|
if model_average:
|
||||||
|
Model_Average = paddle.incubate.optimizer.ModelAverage(
|
||||||
|
0.15,
|
||||||
|
parameters=model.parameters(),
|
||||||
|
min_average_window=10000,
|
||||||
|
max_average_window=15625)
|
||||||
|
Model_Average.apply()
|
||||||
cur_metirc = eval(model, valid_dataloader, post_process_class,
|
cur_metirc = eval(model, valid_dataloader, post_process_class,
|
||||||
eval_class)
|
eval_class)
|
||||||
cur_metirc_str = 'cur metirc, {}'.format(', '.join(
|
cur_metric_str = 'cur metric, {}'.format(', '.join(
|
||||||
['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
|
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
|
||||||
logger.info(cur_metirc_str)
|
logger.info(cur_metric_str)
|
||||||
|
|
||||||
# logger metric
|
# logger metric
|
||||||
if vdl_writer is not None:
|
if vdl_writer is not None:
|
||||||
for k, v in cur_metirc.items():
|
for k, v in cur_metric.items():
|
||||||
if isinstance(v, (float, int)):
|
if isinstance(v, (float, int)):
|
||||||
vdl_writer.add_scalar('EVAL/{}'.format(k),
|
vdl_writer.add_scalar('EVAL/{}'.format(k),
|
||||||
cur_metirc[k], global_step)
|
cur_metric[k], global_step)
|
||||||
if cur_metirc[main_indicator] >= best_model_dict[
|
if cur_metric[main_indicator] >= best_model_dict[
|
||||||
main_indicator]:
|
main_indicator]:
|
||||||
best_model_dict.update(cur_metirc)
|
best_model_dict.update(cur_metric)
|
||||||
best_model_dict['best_epoch'] = epoch
|
best_model_dict['best_epoch'] = epoch
|
||||||
save_model(
|
save_model(
|
||||||
model,
|
model,
|
||||||
|
@ -263,7 +276,7 @@ def train(config,
|
||||||
prefix='best_accuracy',
|
prefix='best_accuracy',
|
||||||
best_model_dict=best_model_dict,
|
best_model_dict=best_model_dict,
|
||||||
epoch=epoch)
|
epoch=epoch)
|
||||||
best_str = 'best metirc, {}'.format(', '.join([
|
best_str = 'best metric, {}'.format(', '.join([
|
||||||
'{}: {}'.format(k, v) for k, v in best_model_dict.items()
|
'{}: {}'.format(k, v) for k, v in best_model_dict.items()
|
||||||
]))
|
]))
|
||||||
logger.info(best_str)
|
logger.info(best_str)
|
||||||
|
@ -273,6 +286,7 @@ def train(config,
|
||||||
best_model_dict[main_indicator],
|
best_model_dict[main_indicator],
|
||||||
global_step)
|
global_step)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
|
optimizer.clear_grad()
|
||||||
batch_start = time.time()
|
batch_start = time.time()
|
||||||
if dist.get_rank() == 0:
|
if dist.get_rank() == 0:
|
||||||
save_model(
|
save_model(
|
||||||
|
@ -294,7 +308,7 @@ def train(config,
|
||||||
prefix='iter_epoch_{}'.format(epoch),
|
prefix='iter_epoch_{}'.format(epoch),
|
||||||
best_model_dict=best_model_dict,
|
best_model_dict=best_model_dict,
|
||||||
epoch=epoch)
|
epoch=epoch)
|
||||||
best_str = 'best metirc, {}'.format(', '.join(
|
best_str = 'best metric, {}'.format(', '.join(
|
||||||
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
|
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
|
||||||
logger.info(best_str)
|
logger.info(best_str)
|
||||||
if dist.get_rank() == 0 and vdl_writer is not None:
|
if dist.get_rank() == 0 and vdl_writer is not None:
|
||||||
|
@ -312,8 +326,9 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
|
||||||
if idx >= len(valid_dataloader):
|
if idx >= len(valid_dataloader):
|
||||||
break
|
break
|
||||||
images = batch[0]
|
images = batch[0]
|
||||||
|
others = batch[-4:]
|
||||||
start = time.time()
|
start = time.time()
|
||||||
preds = model(images)
|
preds = model(images, others)
|
||||||
|
|
||||||
batch = [item.numpy() for item in batch]
|
batch = [item.numpy() for item in batch]
|
||||||
# Obtain usable results from post-processing methods
|
# Obtain usable results from post-processing methods
|
||||||
|
@ -323,13 +338,13 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
|
||||||
eval_class(post_result, batch)
|
eval_class(post_result, batch)
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
total_frame += len(images)
|
total_frame += len(images)
|
||||||
# Get final metirc,eg. acc or hmean
|
# Get final metric,eg. acc or hmean
|
||||||
metirc = eval_class.get_metric()
|
metric = eval_class.get_metric()
|
||||||
|
|
||||||
pbar.close()
|
pbar.close()
|
||||||
model.train()
|
model.train()
|
||||||
metirc['fps'] = total_frame / total_time
|
metric['fps'] = total_frame / total_time
|
||||||
return metirc
|
return metric
|
||||||
|
|
||||||
|
|
||||||
def preprocess(is_train=False):
|
def preprocess(is_train=False):
|
||||||
|
|
5
train.sh
5
train.sh
|
@ -1,5 +1,2 @@
|
||||||
# for paddle.__version__ >= 2.0rc1
|
# recommended paddle.__version__ == 2.0.0
|
||||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
|
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
|
||||||
|
|
||||||
# for paddle.__version__ < 2.0rc1
|
|
||||||
# python3 -m paddle.distributed.launch --selected_gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
|
|
||||||
|
|
Loading…
Reference in New Issue