commit
dd6f6f5cf3
|
@ -0,0 +1,102 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: true
|
||||||
|
epoch_num: 72
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 10
|
||||||
|
save_model_dir: ./output/rec/rec_mv3_tps_bilstm_att/
|
||||||
|
save_epoch_step: 3
|
||||||
|
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||||
|
eval_batch_step: [0, 2000]
|
||||||
|
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||||
|
cal_metric_during_train: True
|
||||||
|
pretrained_model:
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||||
|
# for data or label process
|
||||||
|
character_dict_path:
|
||||||
|
character_type: en
|
||||||
|
max_text_length: 25
|
||||||
|
infer_mode: False
|
||||||
|
use_space_char: False
|
||||||
|
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
lr:
|
||||||
|
learning_rate: 0.0005
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
factor: 0.00001
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: rec
|
||||||
|
algorithm: RARE
|
||||||
|
Transform:
|
||||||
|
name: TPS
|
||||||
|
num_fiducial: 20
|
||||||
|
loc_lr: 0.1
|
||||||
|
model_name: small
|
||||||
|
Backbone:
|
||||||
|
name: MobileNetV3
|
||||||
|
scale: 0.5
|
||||||
|
model_name: large
|
||||||
|
Neck:
|
||||||
|
name: SequenceEncoder
|
||||||
|
encoder_type: rnn
|
||||||
|
hidden_size: 96
|
||||||
|
Head:
|
||||||
|
name: AttentionHead
|
||||||
|
hidden_size: 96
|
||||||
|
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: AttentionLoss
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: AttnLabelDecode
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: RecMetric
|
||||||
|
main_indicator: acc
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDateSet
|
||||||
|
data_dir: ../training/
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- AttnLabelEncode: # Class handling label
|
||||||
|
- RecResizeImg:
|
||||||
|
image_shape: [3, 32, 100]
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
batch_size_per_card: 256
|
||||||
|
drop_last: True
|
||||||
|
num_workers: 8
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDateSet
|
||||||
|
data_dir: ../validation/
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- AttnLabelEncode: # Class handling label
|
||||||
|
- RecResizeImg:
|
||||||
|
image_shape: [3, 32, 100]
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 256
|
||||||
|
num_workers: 1
|
|
@ -0,0 +1,101 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: true
|
||||||
|
epoch_num: 400
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 10
|
||||||
|
save_model_dir: ./output/rec/b3_rare_r34_none_gru/
|
||||||
|
save_epoch_step: 3
|
||||||
|
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||||
|
eval_batch_step: [0, 2000]
|
||||||
|
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||||
|
cal_metric_during_train: True
|
||||||
|
pretrained_model:
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||||
|
# for data or label process
|
||||||
|
character_dict_path:
|
||||||
|
character_type: en
|
||||||
|
max_text_length: 25
|
||||||
|
infer_mode: False
|
||||||
|
use_space_char: False
|
||||||
|
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
lr:
|
||||||
|
learning_rate: 0.0005
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
factor: 0.00000
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: rec
|
||||||
|
algorithm: RARE
|
||||||
|
Transform:
|
||||||
|
name: TPS
|
||||||
|
num_fiducial: 20
|
||||||
|
loc_lr: 0.1
|
||||||
|
model_name: large
|
||||||
|
Backbone:
|
||||||
|
name: ResNet
|
||||||
|
layers: 34
|
||||||
|
Neck:
|
||||||
|
name: SequenceEncoder
|
||||||
|
encoder_type: rnn
|
||||||
|
hidden_size: 256 #96
|
||||||
|
Head:
|
||||||
|
name: AttentionHead # AttentionHead
|
||||||
|
hidden_size: 256 #
|
||||||
|
l2_decay: 0.00001
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: AttentionLoss
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: AttnLabelDecode
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: RecMetric
|
||||||
|
main_indicator: acc
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDateSet
|
||||||
|
data_dir: ../training/
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- AttnLabelEncode: # Class handling label
|
||||||
|
- RecResizeImg:
|
||||||
|
image_shape: [3, 32, 100]
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
batch_size_per_card: 256
|
||||||
|
drop_last: True
|
||||||
|
num_workers: 8
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: LMDBDateSet
|
||||||
|
data_dir: ../validation/
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- AttnLabelEncode: # Class handling label
|
||||||
|
- RecResizeImg:
|
||||||
|
image_shape: [3, 32, 100]
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 256
|
||||||
|
num_workers: 8
|
|
@ -40,7 +40,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|
||||||
- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))[7](ppocr推荐)
|
- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))[7](ppocr推荐)
|
||||||
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
|
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
|
||||||
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
|
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
|
||||||
- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] coming soon
|
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
|
||||||
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
|
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
|
||||||
|
|
||||||
参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
||||||
|
@ -53,6 +53,9 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|
||||||
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|
||||||
|StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)|
|
|StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)|
|
||||||
|StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)|
|
|StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)|
|
||||||
|
|RARE|MobileNetV3|82.5|rec_mv3_tps_bilstm_att||[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|
||||||
|
|RARE|Resnet34_vd|83.6|rec_r34_vd_tps_bilstm_att||[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
||||||
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
|
||||||
|
|
||||||
|
|
||||||
PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。
|
PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。
|
||||||
|
|
|
@ -201,6 +201,8 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
|
||||||
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
|
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
|
||||||
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
|
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
|
||||||
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
|
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
|
||||||
|
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
|
||||||
|
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
||||||
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
||||||
|
|
||||||
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
|
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
|
||||||
|
|
|
@ -42,7 +42,7 @@ PaddleOCR open-source text recognition algorithms list:
|
||||||
- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))[7]
|
- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))[7]
|
||||||
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
|
- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
|
||||||
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
|
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
|
||||||
- [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] coming soon
|
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
|
||||||
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
|
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
|
||||||
|
|
||||||
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
|
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
|
||||||
|
@ -55,6 +55,8 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|
||||||
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|
|CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)|
|
||||||
|StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)|
|
|StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)|
|
||||||
|StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)|
|
|StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)|
|
||||||
|
|RARE|MobileNetV3|82.5|rec_mv3_tps_bilstm_att||[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|
||||||
|
|RARE|Resnet34_vd|83.6|rec_r34_vd_tps_bilstm_att||[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
||||||
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
|
||||||
|
|
||||||
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
|
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
|
||||||
|
|
|
@ -195,8 +195,11 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
|
||||||
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
|
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
|
||||||
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
|
| rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
|
||||||
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
|
| rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
|
||||||
|
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
|
||||||
|
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
||||||
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
||||||
|
|
||||||
|
|
||||||
For training Chinese data, it is recommended to use
|
For training Chinese data, it is recommended to use
|
||||||
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
|
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
|
||||||
co
|
co
|
||||||
|
|
|
@ -199,16 +199,30 @@ class AttnLabelEncode(BaseRecLabelEncode):
|
||||||
super(AttnLabelEncode,
|
super(AttnLabelEncode,
|
||||||
self).__init__(max_text_length, character_dict_path,
|
self).__init__(max_text_length, character_dict_path,
|
||||||
character_type, use_space_char)
|
character_type, use_space_char)
|
||||||
self.beg_str = "sos"
|
|
||||||
self.end_str = "eos"
|
|
||||||
|
|
||||||
def add_special_char(self, dict_character):
|
def add_special_char(self, dict_character):
|
||||||
dict_character = [self.beg_str, self.end_str] + dict_character
|
self.beg_str = "sos"
|
||||||
|
self.end_str = "eos"
|
||||||
|
dict_character = [self.beg_str] + dict_character + [self.end_str]
|
||||||
return dict_character
|
return dict_character
|
||||||
|
|
||||||
def __call__(self, text):
|
def __call__(self, data):
|
||||||
|
text = data['label']
|
||||||
text = self.encode(text)
|
text = self.encode(text)
|
||||||
return text
|
if text is None:
|
||||||
|
return None
|
||||||
|
if len(text) >= self.max_text_len:
|
||||||
|
return None
|
||||||
|
data['length'] = np.array(len(text))
|
||||||
|
text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
|
||||||
|
- len(text) - 1)
|
||||||
|
data['label'] = np.array(text)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def get_ignored_tokens(self):
|
||||||
|
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||||
|
end_idx = self.get_beg_end_flag_idx("end")
|
||||||
|
return [beg_idx, end_idx]
|
||||||
|
|
||||||
def get_beg_end_flag_idx(self, beg_or_end):
|
def get_beg_end_flag_idx(self, beg_or_end):
|
||||||
if beg_or_end == "beg":
|
if beg_or_end == "beg":
|
||||||
|
|
|
@ -23,13 +23,15 @@ def build_loss(config):
|
||||||
|
|
||||||
# rec loss
|
# rec loss
|
||||||
from .rec_ctc_loss import CTCLoss
|
from .rec_ctc_loss import CTCLoss
|
||||||
|
from .rec_att_loss import AttentionLoss
|
||||||
from .rec_srn_loss import SRNLoss
|
from .rec_srn_loss import SRNLoss
|
||||||
|
|
||||||
# cls loss
|
# cls loss
|
||||||
from .cls_loss import ClsLoss
|
from .cls_loss import ClsLoss
|
||||||
|
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'SRNLoss'
|
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||||
|
'SRNLoss'
|
||||||
]
|
]
|
||||||
|
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionLoss(nn.Layer):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super(AttentionLoss, self).__init__()
|
||||||
|
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
|
||||||
|
|
||||||
|
def forward(self, predicts, batch):
|
||||||
|
targets = batch[1].astype("int64")
|
||||||
|
label_lengths = batch[2].astype('int64')
|
||||||
|
batch_size, num_steps, num_classes = predicts.shape[0], predicts.shape[
|
||||||
|
1], predicts.shape[2]
|
||||||
|
assert len(targets.shape) == len(list(predicts.shape)) - 1, \
|
||||||
|
"The target's shape and inputs's shape is [N, d] and [N, num_steps]"
|
||||||
|
|
||||||
|
inputs = paddle.reshape(predicts, [-1, predicts.shape[-1]])
|
||||||
|
targets = paddle.reshape(targets, [-1])
|
||||||
|
|
||||||
|
return {'loss': paddle.sum(self.loss_func(inputs, targets))}
|
|
@ -23,12 +23,14 @@ def build_head(config):
|
||||||
|
|
||||||
# rec head
|
# rec head
|
||||||
from .rec_ctc_head import CTCHead
|
from .rec_ctc_head import CTCHead
|
||||||
|
from .rec_att_head import AttentionHead
|
||||||
from .rec_srn_head import SRNHead
|
from .rec_srn_head import SRNHead
|
||||||
|
|
||||||
# cls head
|
# cls head
|
||||||
from .cls_head import ClsHead
|
from .cls_head import ClsHead
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'SRNHead'
|
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
||||||
|
'SRNHead'
|
||||||
]
|
]
|
||||||
|
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
|
|
|
@ -0,0 +1,199 @@
|
||||||
|
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionHead(nn.Layer):
|
||||||
|
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
|
||||||
|
super(AttentionHead, self).__init__()
|
||||||
|
self.input_size = in_channels
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_classes = out_channels
|
||||||
|
|
||||||
|
self.attention_cell = AttentionGRUCell(
|
||||||
|
in_channels, hidden_size, out_channels, use_gru=False)
|
||||||
|
self.generator = nn.Linear(hidden_size, out_channels)
|
||||||
|
|
||||||
|
def _char_to_onehot(self, input_char, onehot_dim):
|
||||||
|
input_ont_hot = F.one_hot(input_char, onehot_dim)
|
||||||
|
return input_ont_hot
|
||||||
|
|
||||||
|
def forward(self, inputs, targets=None, batch_max_length=25):
|
||||||
|
batch_size = inputs.shape[0]
|
||||||
|
num_steps = batch_max_length
|
||||||
|
|
||||||
|
hidden = paddle.zeros((batch_size, self.hidden_size))
|
||||||
|
output_hiddens = []
|
||||||
|
|
||||||
|
if targets is not None:
|
||||||
|
for i in range(num_steps):
|
||||||
|
char_onehots = self._char_to_onehot(
|
||||||
|
targets[:, i], onehot_dim=self.num_classes)
|
||||||
|
(outputs, hidden), alpha = self.attention_cell(hidden, inputs,
|
||||||
|
char_onehots)
|
||||||
|
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
|
||||||
|
output = paddle.concat(output_hiddens, axis=1)
|
||||||
|
probs = self.generator(output)
|
||||||
|
|
||||||
|
else:
|
||||||
|
targets = paddle.zeros(shape=[batch_size], dtype="int32")
|
||||||
|
probs = None
|
||||||
|
|
||||||
|
for i in range(num_steps):
|
||||||
|
char_onehots = self._char_to_onehot(
|
||||||
|
targets, onehot_dim=self.num_classes)
|
||||||
|
(outputs, hidden), alpha = self.attention_cell(hidden, inputs,
|
||||||
|
char_onehots)
|
||||||
|
probs_step = self.generator(outputs)
|
||||||
|
if probs is None:
|
||||||
|
probs = paddle.unsqueeze(probs_step, axis=1)
|
||||||
|
else:
|
||||||
|
probs = paddle.concat(
|
||||||
|
[probs, paddle.unsqueeze(
|
||||||
|
probs_step, axis=1)], axis=1)
|
||||||
|
next_input = probs_step.argmax(axis=1)
|
||||||
|
targets = next_input
|
||||||
|
|
||||||
|
return probs
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionGRUCell(nn.Layer):
|
||||||
|
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
|
||||||
|
super(AttentionGRUCell, self).__init__()
|
||||||
|
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
|
||||||
|
self.h2h = nn.Linear(hidden_size, hidden_size)
|
||||||
|
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
|
||||||
|
|
||||||
|
self.rnn = nn.GRUCell(
|
||||||
|
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
|
||||||
|
def forward(self, prev_hidden, batch_H, char_onehots):
|
||||||
|
|
||||||
|
batch_H_proj = self.i2h(batch_H)
|
||||||
|
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
|
||||||
|
|
||||||
|
res = paddle.add(batch_H_proj, prev_hidden_proj)
|
||||||
|
res = paddle.tanh(res)
|
||||||
|
e = self.score(res)
|
||||||
|
|
||||||
|
alpha = F.softmax(e, axis=1)
|
||||||
|
alpha = paddle.transpose(alpha, [0, 2, 1])
|
||||||
|
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
|
||||||
|
concat_context = paddle.concat([context, char_onehots], 1)
|
||||||
|
|
||||||
|
cur_hidden = self.rnn(concat_context, prev_hidden)
|
||||||
|
|
||||||
|
return cur_hidden, alpha
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionLSTM(nn.Layer):
|
||||||
|
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
|
||||||
|
super(AttentionLSTM, self).__init__()
|
||||||
|
self.input_size = in_channels
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_classes = out_channels
|
||||||
|
|
||||||
|
self.attention_cell = AttentionLSTMCell(
|
||||||
|
in_channels, hidden_size, out_channels, use_gru=False)
|
||||||
|
self.generator = nn.Linear(hidden_size, out_channels)
|
||||||
|
|
||||||
|
def _char_to_onehot(self, input_char, onehot_dim):
|
||||||
|
input_ont_hot = F.one_hot(input_char, onehot_dim)
|
||||||
|
return input_ont_hot
|
||||||
|
|
||||||
|
def forward(self, inputs, targets=None, batch_max_length=25):
|
||||||
|
batch_size = inputs.shape[0]
|
||||||
|
num_steps = batch_max_length
|
||||||
|
|
||||||
|
hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
|
||||||
|
(batch_size, self.hidden_size)))
|
||||||
|
output_hiddens = []
|
||||||
|
|
||||||
|
if targets is not None:
|
||||||
|
for i in range(num_steps):
|
||||||
|
# one-hot vectors for a i-th char
|
||||||
|
char_onehots = self._char_to_onehot(
|
||||||
|
targets[:, i], onehot_dim=self.num_classes)
|
||||||
|
hidden, alpha = self.attention_cell(hidden, inputs,
|
||||||
|
char_onehots)
|
||||||
|
|
||||||
|
hidden = (hidden[1][0], hidden[1][1])
|
||||||
|
output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
|
||||||
|
output = paddle.concat(output_hiddens, axis=1)
|
||||||
|
probs = self.generator(output)
|
||||||
|
|
||||||
|
else:
|
||||||
|
targets = paddle.zeros(shape=[batch_size], dtype="int32")
|
||||||
|
probs = None
|
||||||
|
|
||||||
|
for i in range(num_steps):
|
||||||
|
char_onehots = self._char_to_onehot(
|
||||||
|
targets, onehot_dim=self.num_classes)
|
||||||
|
hidden, alpha = self.attention_cell(hidden, inputs,
|
||||||
|
char_onehots)
|
||||||
|
probs_step = self.generator(hidden[0])
|
||||||
|
hidden = (hidden[1][0], hidden[1][1])
|
||||||
|
if probs is None:
|
||||||
|
probs = paddle.unsqueeze(probs_step, axis=1)
|
||||||
|
else:
|
||||||
|
probs = paddle.concat(
|
||||||
|
[probs, paddle.unsqueeze(
|
||||||
|
probs_step, axis=1)], axis=1)
|
||||||
|
|
||||||
|
next_input = probs_step.argmax(axis=1)
|
||||||
|
|
||||||
|
targets = next_input
|
||||||
|
|
||||||
|
return probs
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionLSTMCell(nn.Layer):
|
||||||
|
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
|
||||||
|
super(AttentionLSTMCell, self).__init__()
|
||||||
|
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
|
||||||
|
self.h2h = nn.Linear(hidden_size, hidden_size)
|
||||||
|
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
|
||||||
|
if not use_gru:
|
||||||
|
self.rnn = nn.LSTMCell(
|
||||||
|
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||||
|
else:
|
||||||
|
self.rnn = nn.GRUCell(
|
||||||
|
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
|
||||||
|
def forward(self, prev_hidden, batch_H, char_onehots):
|
||||||
|
batch_H_proj = self.i2h(batch_H)
|
||||||
|
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
|
||||||
|
res = paddle.add(batch_H_proj, prev_hidden_proj)
|
||||||
|
res = paddle.tanh(res)
|
||||||
|
e = self.score(res)
|
||||||
|
|
||||||
|
alpha = F.softmax(e, axis=1)
|
||||||
|
alpha = paddle.transpose(alpha, [0, 2, 1])
|
||||||
|
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
|
||||||
|
concat_context = paddle.concat([context, char_onehots], 1)
|
||||||
|
cur_hidden = self.rnn(concat_context, prev_hidden)
|
||||||
|
|
||||||
|
return cur_hidden, alpha
|
|
@ -135,16 +135,62 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
||||||
**kwargs):
|
**kwargs):
|
||||||
super(AttnLabelDecode, self).__init__(character_dict_path,
|
super(AttnLabelDecode, self).__init__(character_dict_path,
|
||||||
character_type, use_space_char)
|
character_type, use_space_char)
|
||||||
self.beg_str = "sos"
|
|
||||||
self.end_str = "eos"
|
|
||||||
|
|
||||||
def add_special_char(self, dict_character):
|
def add_special_char(self, dict_character):
|
||||||
dict_character = [self.beg_str, self.end_str] + dict_character
|
self.beg_str = "sos"
|
||||||
|
self.end_str = "eos"
|
||||||
|
dict_character = dict_character
|
||||||
|
dict_character = [self.beg_str] + dict_character + [self.end_str]
|
||||||
return dict_character
|
return dict_character
|
||||||
|
|
||||||
def __call__(self, text):
|
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||||
|
""" convert text-index into text-label. """
|
||||||
|
result_list = []
|
||||||
|
ignored_tokens = self.get_ignored_tokens()
|
||||||
|
[beg_idx, end_idx] = self.get_ignored_tokens()
|
||||||
|
batch_size = len(text_index)
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
char_list = []
|
||||||
|
conf_list = []
|
||||||
|
for idx in range(len(text_index[batch_idx])):
|
||||||
|
if text_index[batch_idx][idx] in ignored_tokens:
|
||||||
|
continue
|
||||||
|
if int(text_index[batch_idx][idx]) == int(end_idx):
|
||||||
|
break
|
||||||
|
if is_remove_duplicate:
|
||||||
|
# only for predict
|
||||||
|
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||||
|
batch_idx][idx]:
|
||||||
|
continue
|
||||||
|
char_list.append(self.character[int(text_index[batch_idx][
|
||||||
|
idx])])
|
||||||
|
if text_prob is not None:
|
||||||
|
conf_list.append(text_prob[batch_idx][idx])
|
||||||
|
else:
|
||||||
|
conf_list.append(1)
|
||||||
|
text = ''.join(char_list)
|
||||||
|
result_list.append((text, np.mean(conf_list)))
|
||||||
|
return result_list
|
||||||
|
|
||||||
|
def __call__(self, preds, label=None, *args, **kwargs):
|
||||||
|
"""
|
||||||
text = self.decode(text)
|
text = self.decode(text)
|
||||||
return text
|
if label is None:
|
||||||
|
return text
|
||||||
|
else:
|
||||||
|
label = self.decode(label, is_remove_duplicate=False)
|
||||||
|
return text, label
|
||||||
|
"""
|
||||||
|
if isinstance(preds, paddle.Tensor):
|
||||||
|
preds = preds.numpy()
|
||||||
|
|
||||||
|
preds_idx = preds.argmax(axis=2)
|
||||||
|
preds_prob = preds.max(axis=2)
|
||||||
|
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||||
|
if label is None:
|
||||||
|
return text
|
||||||
|
label = self.decode(label, is_remove_duplicate=False)
|
||||||
|
return text, label
|
||||||
|
|
||||||
def get_ignored_tokens(self):
|
def get_ignored_tokens(self):
|
||||||
beg_idx = self.get_beg_end_flag_idx("beg")
|
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||||
|
|
|
@ -184,4 +184,4 @@ def main(args):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main(utility.parse_args())
|
main(utility.parse_args())
|
||||||
|
|
Loading…
Reference in New Issue