add rec_nrtr
This commit is contained in:
parent
b6f0a90366
commit
1623c17cdc
|
@ -3,22 +3,38 @@ Global:
|
||||||
epoch_num: 21
|
epoch_num: 21
|
||||||
log_smooth_window: 20
|
log_smooth_window: 20
|
||||||
print_batch_step: 10
|
print_batch_step: 10
|
||||||
|
<<<<<<< HEAD
|
||||||
save_model_dir: ./output/rec/nrtr_final/
|
save_model_dir: ./output/rec/nrtr_final/
|
||||||
save_epoch_step: 1
|
save_epoch_step: 1
|
||||||
# evaluation is run every 2000 iterations
|
# evaluation is run every 2000 iterations
|
||||||
eval_batch_step: [0, 2000]
|
eval_batch_step: [0, 2000]
|
||||||
cal_metric_during_train: True
|
cal_metric_during_train: True
|
||||||
|
=======
|
||||||
|
save_model_dir: ./output/rec/piloptimnrtr/
|
||||||
|
save_epoch_step: 1
|
||||||
|
# evaluation is run every 2000 iterations
|
||||||
|
eval_batch_step: [0, 2000]
|
||||||
|
cal_metric_during_train: False
|
||||||
|
>>>>>>> 9c67a7f... add rec_nrtr
|
||||||
pretrained_model:
|
pretrained_model:
|
||||||
checkpoints:
|
checkpoints:
|
||||||
save_inference_dir:
|
save_inference_dir:
|
||||||
use_visualdl: False
|
use_visualdl: False
|
||||||
infer_img: doc/imgs_words_en/word_10.png
|
infer_img: doc/imgs_words_en/word_10.png
|
||||||
# for data or label process
|
# for data or label process
|
||||||
|
<<<<<<< HEAD
|
||||||
character_dict_path:
|
character_dict_path:
|
||||||
character_type: EN_symbol
|
character_type: EN_symbol
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
infer_mode: False
|
infer_mode: False
|
||||||
use_space_char: True
|
use_space_char: True
|
||||||
|
=======
|
||||||
|
character_dict_path: ppocr/utils/dict_99.txt
|
||||||
|
character_type: dict_99
|
||||||
|
max_text_length: 25
|
||||||
|
infer_mode: False
|
||||||
|
use_space_char: False
|
||||||
|
>>>>>>> 9c67a7f... add rec_nrtr
|
||||||
save_res_path: ./output/rec/predicts_nrtr.txt
|
save_res_path: ./output/rec/predicts_nrtr.txt
|
||||||
|
|
||||||
Optimizer:
|
Optimizer:
|
||||||
|
|
|
@ -44,6 +44,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|
||||||
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
|
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
|
||||||
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
|
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
|
||||||
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
|
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
|
||||||
|
- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2)
|
||||||
|
|
||||||
参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
||||||
|
|
||||||
|
@ -58,6 +59,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|
||||||
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|
||||||
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
||||||
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
|
||||||
|
|NRTR|NRTR_MTB| 84.1% | rec_mtb_nrtr | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
|
||||||
|
|
||||||
|
|
||||||
PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。
|
PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。
|
||||||
|
|
|
@ -215,6 +215,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
|
||||||
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
|
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
|
||||||
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
||||||
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
||||||
|
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
|
||||||
|
|
||||||
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
|
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
|
||||||
|
|
||||||
|
|
|
@ -60,5 +60,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|
||||||
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|
||||||
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
||||||
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
|
||||||
|
|NRTR|NRTR_MTB| 84.1% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
|
||||||
|
|
||||||
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
|
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
|
||||||
|
|
|
@ -207,7 +207,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
|
||||||
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
|
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
|
||||||
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
||||||
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
||||||
|
| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
|
||||||
|
|
||||||
For training Chinese data, it is recommended to use
|
For training Chinese data, it is recommended to use
|
||||||
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
|
[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
|
||||||
|
|
|
@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
|
||||||
from .make_shrink_map import MakeShrinkMap
|
from .make_shrink_map import MakeShrinkMap
|
||||||
from .random_crop_data import EastRandomCropData, PSERandomCrop
|
from .random_crop_data import EastRandomCropData, PSERandomCrop
|
||||||
|
|
||||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
|
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, PILResize, CVResize
|
||||||
from .randaugment import RandAugment
|
from .randaugment import RandAugment
|
||||||
from .operators import *
|
from .operators import *
|
||||||
from .label_ops import *
|
from .label_ops import *
|
||||||
|
|
|
@ -96,7 +96,7 @@ class BaseRecLabelEncode(object):
|
||||||
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
|
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
|
||||||
'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
|
'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
|
||||||
'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
|
'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
|
||||||
'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari'
|
'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari','dict_99'
|
||||||
]
|
]
|
||||||
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
||||||
support_character_type, character_type)
|
support_character_type, character_type)
|
||||||
|
|
|
@ -57,6 +57,38 @@ class DecodeImage(object):
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class NRTRDecodeImage(object):
|
||||||
|
""" decode image """
|
||||||
|
|
||||||
|
def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
|
||||||
|
self.img_mode = img_mode
|
||||||
|
self.channel_first = channel_first
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
img = data['image']
|
||||||
|
if six.PY2:
|
||||||
|
assert type(img) is str and len(
|
||||||
|
img) > 0, "invalid input 'img' in DecodeImage"
|
||||||
|
else:
|
||||||
|
assert type(img) is bytes and len(
|
||||||
|
img) > 0, "invalid input 'img' in DecodeImage"
|
||||||
|
img = np.frombuffer(img, dtype='uint8')
|
||||||
|
|
||||||
|
img = cv2.imdecode(img, 1)
|
||||||
|
|
||||||
|
if img is None:
|
||||||
|
return None
|
||||||
|
if self.img_mode == 'GRAY':
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||||
|
elif self.img_mode == 'RGB':
|
||||||
|
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
|
||||||
|
img = img[:, :, ::-1]
|
||||||
|
img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
|
||||||
|
if self.channel_first:
|
||||||
|
img = img.transpose((2, 0, 1))
|
||||||
|
data['image'] = img
|
||||||
|
return data
|
||||||
|
|
||||||
class NormalizeImage(object):
|
class NormalizeImage(object):
|
||||||
""" normalize image such as substract mean, divide std
|
""" normalize image such as substract mean, divide std
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -16,7 +16,7 @@ import math
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
|
from PIL import Image
|
||||||
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
|
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,6 +42,34 @@ class ClsResizeImg(object):
|
||||||
data['image'] = norm_img
|
data['image'] = norm_img
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
class PILResize(object):
|
||||||
|
def __init__(self, image_shape, **kwargs):
|
||||||
|
self.image_shape = image_shape
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
img = data['image']
|
||||||
|
image_pil = Image.fromarray(np.uint8(img))
|
||||||
|
norm_img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
|
||||||
|
norm_img = np.array(norm_img)
|
||||||
|
norm_img = np.expand_dims(norm_img, -1)
|
||||||
|
norm_img = norm_img.transpose((2, 0, 1))
|
||||||
|
data['image'] = norm_img.astype(np.float32) / 128. - 1.
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class CVResize(object):
|
||||||
|
def __init__(self, image_shape, **kwargs):
|
||||||
|
self.image_shape = image_shape
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
img = data['image']
|
||||||
|
#print(img)
|
||||||
|
norm_img = cv2.resize(img,self.image_shape)
|
||||||
|
norm_img = np.expand_dims(norm_img, -1)
|
||||||
|
norm_img = norm_img.transpose((2, 0, 1))
|
||||||
|
data['image'] = norm_img.astype(np.float32) / 128. - 1.
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class RecResizeImg(object):
|
class RecResizeImg(object):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
|
|
@ -25,7 +25,7 @@ from .det_sast_loss import SASTLoss
|
||||||
from .rec_ctc_loss import CTCLoss
|
from .rec_ctc_loss import CTCLoss
|
||||||
from .rec_att_loss import AttentionLoss
|
from .rec_att_loss import AttentionLoss
|
||||||
from .rec_srn_loss import SRNLoss
|
from .rec_srn_loss import SRNLoss
|
||||||
|
from .rec_nrtr_loss import NRTRLoss
|
||||||
# cls loss
|
# cls loss
|
||||||
from .cls_loss import ClsLoss
|
from .cls_loss import ClsLoss
|
||||||
|
|
||||||
|
@ -42,8 +42,8 @@ from .combined_loss import CombinedLoss
|
||||||
def build_loss(config):
|
def build_loss(config):
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||||
'SRNLoss', 'PGLoss', 'CombinedLoss'
|
'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss']
|
||||||
]
|
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
assert module_name in support_dict, Exception('loss only support {}'.format(
|
assert module_name in support_dict, Exception('loss only support {}'.format(
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def cal_performance(pred, tgt):
|
||||||
|
|
||||||
|
pred = pred.max(1)[1]
|
||||||
|
tgt = tgt.contiguous().view(-1)
|
||||||
|
non_pad_mask = tgt.ne(0)
|
||||||
|
n_correct = pred.eq(tgt)
|
||||||
|
n_correct = n_correct.masked_select(non_pad_mask).sum().item()
|
||||||
|
return n_correct
|
||||||
|
|
||||||
|
|
||||||
|
class NRTRLoss(nn.Layer):
|
||||||
|
def __init__(self,smoothing=True, **kwargs):
|
||||||
|
super(NRTRLoss, self).__init__()
|
||||||
|
self.loss_func = nn.CrossEntropyLoss(reduction='mean',ignore_index=0)
|
||||||
|
self.smoothing = smoothing
|
||||||
|
|
||||||
|
def forward(self, pred, batch):
|
||||||
|
pred = pred.reshape([-1, pred.shape[2]])
|
||||||
|
max_len = batch[2].max()
|
||||||
|
tgt = batch[1][:,1:2+max_len]
|
||||||
|
tgt = tgt.reshape([-1] )
|
||||||
|
if self.smoothing:
|
||||||
|
eps = 0.1
|
||||||
|
n_class = pred.shape[1]
|
||||||
|
one_hot = F.one_hot(tgt, pred.shape[1])
|
||||||
|
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
|
||||||
|
log_prb = F.log_softmax(pred, axis=1)
|
||||||
|
non_pad_mask = paddle.not_equal(tgt, paddle.zeros(tgt.shape,dtype='int64'))
|
||||||
|
loss = -(one_hot * log_prb).sum(axis=1)
|
||||||
|
loss = loss.masked_select(non_pad_mask).mean()
|
||||||
|
else:
|
||||||
|
loss = self.loss_func(pred, tgt)
|
||||||
|
return {'loss': loss}
|
|
@ -30,7 +30,7 @@ class RecMetric(object):
|
||||||
target = target.replace(" ", "")
|
target = target.replace(" ", "")
|
||||||
norm_edit_dis += Levenshtein.distance(pred, target) / max(
|
norm_edit_dis += Levenshtein.distance(pred, target) / max(
|
||||||
len(pred), len(target), 1)
|
len(pred), len(target), 1)
|
||||||
if pred == target:
|
if pred.lower() == target.lower():
|
||||||
correct_num += 1
|
correct_num += 1
|
||||||
all_num += 1
|
all_num += 1
|
||||||
self.correct_num += correct_num
|
self.correct_num += correct_num
|
||||||
|
@ -48,8 +48,8 @@ class RecMetric(object):
|
||||||
'norm_edit_dis': 0,
|
'norm_edit_dis': 0,
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
acc = 1.0 * self.correct_num / self.all_num
|
acc = 1.0 * self.correct_num / (self.all_num)
|
||||||
norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
|
norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num)
|
||||||
self.reset()
|
self.reset()
|
||||||
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
|
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
|
||||||
|
|
||||||
|
@ -57,3 +57,4 @@ class RecMetric(object):
|
||||||
self.correct_num = 0
|
self.correct_num = 0
|
||||||
self.all_num = 0
|
self.all_num = 0
|
||||||
self.norm_edit_dis = 0
|
self.norm_edit_dis = 0
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
import paddle
|
||||||
from paddle import nn
|
from paddle import nn
|
||||||
from ppocr.modeling.transforms import build_transform
|
from ppocr.modeling.transforms import build_transform
|
||||||
from ppocr.modeling.backbones import build_backbone
|
from ppocr.modeling.backbones import build_backbone
|
||||||
|
|
|
@ -25,7 +25,10 @@ def build_backbone(config, model_type):
|
||||||
from .rec_mobilenet_v3 import MobileNetV3
|
from .rec_mobilenet_v3 import MobileNetV3
|
||||||
from .rec_resnet_vd import ResNet
|
from .rec_resnet_vd import ResNet
|
||||||
from .rec_resnet_fpn import ResNetFPN
|
from .rec_resnet_fpn import ResNetFPN
|
||||||
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
|
from .rec_nrtr_mtb import MTB
|
||||||
|
from .rec_swin import SwinTransformer
|
||||||
|
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN','MTB','SwinTransformer']
|
||||||
|
|
||||||
elif model_type == 'e2e':
|
elif model_type == 'e2e':
|
||||||
from .e2e_resnet_vd_pg import ResNet
|
from .e2e_resnet_vd_pg import ResNet
|
||||||
support_dict = ['ResNet']
|
support_dict = ['ResNet']
|
||||||
|
|
|
@ -0,0 +1,365 @@
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle.nn import Linear
|
||||||
|
from paddle.nn.initializer import XavierUniform as xavier_uniform_
|
||||||
|
from paddle.nn.initializer import Constant as constant_
|
||||||
|
from paddle.nn.initializer import XavierNormal as xavier_normal_
|
||||||
|
|
||||||
|
zeros_ = constant_(value=0.)
|
||||||
|
ones_ = constant_(value=1.)
|
||||||
|
|
||||||
|
class MultiheadAttention(nn.Layer):
|
||||||
|
r"""Allows the model to jointly attend to information
|
||||||
|
from different representation subspaces.
|
||||||
|
See reference: Attention Is All You Need
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
||||||
|
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embed_dim: total dimension of the model
|
||||||
|
num_heads: parallel attention layers, or heads
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
||||||
|
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
|
||||||
|
super(MultiheadAttention, self).__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.dropout = dropout
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||||
|
self.scaling = self.head_dim ** -0.5
|
||||||
|
|
||||||
|
self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
|
||||||
|
|
||||||
|
if add_bias_kv:
|
||||||
|
self.bias_k = self.create_parameter(
|
||||||
|
shape=(1, 1, embed_dim), default_initializer=zeros_)
|
||||||
|
self.add_parameter("bias_k", self.bias_k)
|
||||||
|
self.bias_v = self.create_parameter(
|
||||||
|
shape=(1, 1, embed_dim), default_initializer=zeros_)
|
||||||
|
self.add_parameter("bias_v", self.bias_v)
|
||||||
|
else:
|
||||||
|
self.bias_k = self.bias_v = None
|
||||||
|
|
||||||
|
self.add_zero_attn = add_zero_attn
|
||||||
|
|
||||||
|
self._reset_parameters()
|
||||||
|
|
||||||
|
self.conv1 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
|
||||||
|
self.conv2 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim * 2, kernel_size=(1, 1))
|
||||||
|
self.conv3 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim * 3, kernel_size=(1, 1))
|
||||||
|
|
||||||
|
def _reset_parameters(self):
|
||||||
|
|
||||||
|
|
||||||
|
xavier_uniform_(self.out_proj.weight)
|
||||||
|
if self.bias_k is not None:
|
||||||
|
xavier_normal_(self.bias_k)
|
||||||
|
if self.bias_v is not None:
|
||||||
|
xavier_normal_(self.bias_v)
|
||||||
|
|
||||||
|
def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
|
||||||
|
need_weights=True, static_kv=False, attn_mask=None, qkv_ = [False,False,False]):
|
||||||
|
"""
|
||||||
|
Inputs of forward function
|
||||||
|
query: [target length, batch size, embed dim]
|
||||||
|
key: [sequence length, batch size, embed dim]
|
||||||
|
value: [sequence length, batch size, embed dim]
|
||||||
|
key_padding_mask: if True, mask padding based on batch size
|
||||||
|
incremental_state: if provided, previous time steps are cashed
|
||||||
|
need_weights: output attn_output_weights
|
||||||
|
static_kv: key and value are static
|
||||||
|
|
||||||
|
Outputs of forward function
|
||||||
|
attn_output: [target length, batch size, embed dim]
|
||||||
|
attn_output_weights: [batch size, target length, sequence length]
|
||||||
|
"""
|
||||||
|
qkv_same = qkv_[0]
|
||||||
|
kv_same = qkv_[1]
|
||||||
|
|
||||||
|
tgt_len, bsz, embed_dim = query.shape
|
||||||
|
assert embed_dim == self.embed_dim
|
||||||
|
assert list(query.shape) == [tgt_len, bsz, embed_dim]
|
||||||
|
assert key.shape == value.shape
|
||||||
|
|
||||||
|
if qkv_same:
|
||||||
|
# self-attention
|
||||||
|
q, k, v = self._in_proj_qkv(query)
|
||||||
|
elif kv_same:
|
||||||
|
# encoder-decoder attention
|
||||||
|
q = self._in_proj_q(query)
|
||||||
|
if key is None:
|
||||||
|
assert value is None
|
||||||
|
k = v = None
|
||||||
|
else:
|
||||||
|
k, v = self._in_proj_kv(key)
|
||||||
|
else:
|
||||||
|
q = self._in_proj_q(query)
|
||||||
|
k = self._in_proj_k(key)
|
||||||
|
v = self._in_proj_v(value)
|
||||||
|
q *= self.scaling
|
||||||
|
|
||||||
|
if self.bias_k is not None:
|
||||||
|
assert self.bias_v is not None
|
||||||
|
self.bias_k = paddle.concat([self.bias_k for i in range(bsz)],axis=1)
|
||||||
|
self.bias_v = paddle.concat([self.bias_v for i in range(bsz)],axis=1)
|
||||||
|
k = paddle.concat([k, self.bias_k])
|
||||||
|
v = paddle.concat([v, self.bias_v])
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_mask = paddle.concat([attn_mask, paddle.zeros([attn_mask.shape[0], 1],dtype=attn_mask.dtype)], axis=1)
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
key_padding_mask = paddle.concat(
|
||||||
|
[key_padding_mask,paddle.zeros([key_padding_mask.shape[0], 1],dtype=key_padding_mask.dtype)], axis=1)
|
||||||
|
|
||||||
|
q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
|
||||||
|
if k is not None:
|
||||||
|
k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
|
||||||
|
if v is not None:
|
||||||
|
v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
src_len = k.shape[1]
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
assert key_padding_mask.shape[0] == bsz
|
||||||
|
assert key_padding_mask.shape[1] == src_len
|
||||||
|
|
||||||
|
if self.add_zero_attn:
|
||||||
|
src_len += 1
|
||||||
|
k = paddle.concat([k, paddle.zeros((k.shape[0], 1) + k.shape[2:],dtype=k.dtype)], axis=1)
|
||||||
|
v = paddle.concat([v, paddle.zeros((v.shape[0], 1) + v.shape[2:],dtype=v.dtype)], axis=1)
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_mask = paddle.concat([attn_mask, paddle.zeros([attn_mask.shape[0], 1],dtype=attn_mask.dtype)], axis=1)
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
key_padding_mask = paddle.concat(
|
||||||
|
[key_padding_mask, paddle.zeros([key_padding_mask.shape[0], 1],dtype=key_padding_mask.dtype)], axis=1)
|
||||||
|
attn_output_weights = paddle.bmm(q, k.transpose([0,2,1]))
|
||||||
|
assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
|
attn_output_weights += attn_mask
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
|
||||||
|
key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32')
|
||||||
|
y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf')
|
||||||
|
y = paddle.where(key==0.,key, y)
|
||||||
|
attn_output_weights += y
|
||||||
|
attn_output_weights = attn_output_weights.reshape([bsz*self.num_heads, tgt_len, src_len])
|
||||||
|
|
||||||
|
attn_output_weights = F.softmax(
|
||||||
|
attn_output_weights.astype('float32'), axis=-1,
|
||||||
|
dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 else attn_output_weights.dtype)
|
||||||
|
attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
|
attn_output = paddle.bmm(attn_output_weights, v)
|
||||||
|
assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
||||||
|
attn_output = attn_output.transpose([1, 0,2]).reshape([tgt_len, bsz, embed_dim])
|
||||||
|
attn_output = self.out_proj(attn_output)
|
||||||
|
if need_weights:
|
||||||
|
# average attention weights over heads
|
||||||
|
attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
|
||||||
|
attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads
|
||||||
|
else:
|
||||||
|
attn_output_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_output_weights
|
||||||
|
|
||||||
|
def _in_proj_qkv(self, query):
|
||||||
|
query = query.transpose([1, 2, 0])
|
||||||
|
query = paddle.unsqueeze(query, axis=2)
|
||||||
|
res = self.conv3(query)
|
||||||
|
res = paddle.squeeze(res, axis=2)
|
||||||
|
res = res.transpose([2, 0, 1])
|
||||||
|
return res.chunk(3, axis=-1)
|
||||||
|
|
||||||
|
def _in_proj_kv(self, key):
|
||||||
|
key = key.transpose([1, 2, 0])
|
||||||
|
key = paddle.unsqueeze(key, axis=2)
|
||||||
|
res = self.conv2(key)
|
||||||
|
res = paddle.squeeze(res, axis=2)
|
||||||
|
res = res.transpose([2, 0, 1])
|
||||||
|
return res.chunk(2, axis=-1)
|
||||||
|
|
||||||
|
def _in_proj_q(self, query):
|
||||||
|
query = query.transpose([1, 2, 0])
|
||||||
|
query = paddle.unsqueeze(query, axis=2)
|
||||||
|
res = self.conv1(query)
|
||||||
|
res = paddle.squeeze(res, axis=2)
|
||||||
|
res = res.transpose([2, 0, 1])
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _in_proj_k(self, key):
|
||||||
|
|
||||||
|
key = key.transpose([1, 2, 0])
|
||||||
|
key = paddle.unsqueeze(key, axis=2)
|
||||||
|
res = self.conv1(key)
|
||||||
|
res = paddle.squeeze(res, axis=2)
|
||||||
|
res = res.transpose([2, 0, 1])
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _in_proj_v(self, value):
|
||||||
|
|
||||||
|
value = value.transpose([1,2,0])#(1, 2, 0)
|
||||||
|
value = paddle.unsqueeze(value, axis=2)
|
||||||
|
res = self.conv1(value)
|
||||||
|
res = paddle.squeeze(res, axis=2)
|
||||||
|
res = res.transpose([2, 0, 1])
|
||||||
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class MultiheadAttentionOptim(nn.Layer):
|
||||||
|
r"""Allows the model to jointly attend to information
|
||||||
|
from different representation subspaces.
|
||||||
|
See reference: Attention Is All You Need
|
||||||
|
|
||||||
|
.. math::
|
||||||
|
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
||||||
|
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embed_dim: total dimension of the model
|
||||||
|
num_heads: parallel attention layers, or heads
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
|
||||||
|
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False):
|
||||||
|
super(MultiheadAttentionOptim, self).__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.dropout = dropout
|
||||||
|
self.head_dim = embed_dim // num_heads
|
||||||
|
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||||
|
self.scaling = self.head_dim ** -0.5
|
||||||
|
|
||||||
|
self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
|
||||||
|
|
||||||
|
self._reset_parameters()
|
||||||
|
|
||||||
|
self.conv1 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
|
||||||
|
self.conv2 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
|
||||||
|
self.conv3 = paddle.nn.Conv2D(in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
|
||||||
|
|
||||||
|
def _reset_parameters(self):
|
||||||
|
|
||||||
|
|
||||||
|
xavier_uniform_(self.out_proj.weight)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
|
||||||
|
need_weights=True, static_kv=False, attn_mask=None):
|
||||||
|
"""
|
||||||
|
Inputs of forward function
|
||||||
|
query: [target length, batch size, embed dim]
|
||||||
|
key: [sequence length, batch size, embed dim]
|
||||||
|
value: [sequence length, batch size, embed dim]
|
||||||
|
key_padding_mask: if True, mask padding based on batch size
|
||||||
|
incremental_state: if provided, previous time steps are cashed
|
||||||
|
need_weights: output attn_output_weights
|
||||||
|
static_kv: key and value are static
|
||||||
|
|
||||||
|
Outputs of forward function
|
||||||
|
attn_output: [target length, batch size, embed dim]
|
||||||
|
attn_output_weights: [batch size, target length, sequence length]
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
tgt_len, bsz, embed_dim = query.shape
|
||||||
|
assert embed_dim == self.embed_dim
|
||||||
|
assert list(query.shape) == [tgt_len, bsz, embed_dim]
|
||||||
|
assert key.shape == value.shape
|
||||||
|
|
||||||
|
q = self._in_proj_q(query)
|
||||||
|
k = self._in_proj_k(key)
|
||||||
|
v = self._in_proj_v(value)
|
||||||
|
q *= self.scaling
|
||||||
|
|
||||||
|
|
||||||
|
q = q.reshape([tgt_len, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
|
||||||
|
k = k.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
|
||||||
|
v = v.reshape([-1, bsz * self.num_heads, self.head_dim]).transpose([1, 0, 2])
|
||||||
|
|
||||||
|
|
||||||
|
src_len = k.shape[1]
|
||||||
|
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
assert key_padding_mask.shape[0] == bsz
|
||||||
|
assert key_padding_mask.shape[1] == src_len
|
||||||
|
|
||||||
|
|
||||||
|
attn_output_weights = paddle.bmm(q, k.transpose([0,2,1]))
|
||||||
|
assert list(attn_output_weights.shape) == [bsz * self.num_heads, tgt_len, src_len]
|
||||||
|
|
||||||
|
if attn_mask is not None:
|
||||||
|
attn_mask = attn_mask.unsqueeze(0)
|
||||||
|
attn_output_weights += attn_mask
|
||||||
|
if key_padding_mask is not None:
|
||||||
|
attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
|
||||||
|
key = key_padding_mask.unsqueeze(1).unsqueeze(2).astype('float32')
|
||||||
|
|
||||||
|
y = paddle.full(shape=key.shape, dtype='float32', fill_value='-inf')
|
||||||
|
|
||||||
|
y = paddle.where(key==0.,key, y)
|
||||||
|
|
||||||
|
attn_output_weights += y
|
||||||
|
attn_output_weights = attn_output_weights.reshape([bsz*self.num_heads, tgt_len, src_len])
|
||||||
|
|
||||||
|
attn_output_weights = F.softmax(
|
||||||
|
attn_output_weights.astype('float32'), axis=-1,
|
||||||
|
dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16 else attn_output_weights.dtype)
|
||||||
|
attn_output_weights = F.dropout(attn_output_weights, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
|
attn_output = paddle.bmm(attn_output_weights, v)
|
||||||
|
assert list(attn_output.shape) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
||||||
|
attn_output = attn_output.transpose([1, 0,2]).reshape([tgt_len, bsz, embed_dim])
|
||||||
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
|
if need_weights:
|
||||||
|
# average attention weights over heads
|
||||||
|
attn_output_weights = attn_output_weights.reshape([bsz, self.num_heads, tgt_len, src_len])
|
||||||
|
attn_output_weights = attn_output_weights.sum(axis=1) / self.num_heads
|
||||||
|
else:
|
||||||
|
attn_output_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_output_weights
|
||||||
|
|
||||||
|
|
||||||
|
def _in_proj_q(self, query):
|
||||||
|
query = query.transpose([1, 2, 0])
|
||||||
|
query = paddle.unsqueeze(query, axis=2)
|
||||||
|
res = self.conv1(query)
|
||||||
|
res = paddle.squeeze(res, axis=2)
|
||||||
|
res = res.transpose([2, 0, 1])
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _in_proj_k(self, key):
|
||||||
|
|
||||||
|
key = key.transpose([1, 2, 0])
|
||||||
|
key = paddle.unsqueeze(key, axis=2)
|
||||||
|
res = self.conv2(key)
|
||||||
|
res = paddle.squeeze(res, axis=2)
|
||||||
|
res = res.transpose([2, 0, 1])
|
||||||
|
return res
|
||||||
|
|
||||||
|
def _in_proj_v(self, value):
|
||||||
|
|
||||||
|
value = value.transpose([1,2,0])#(1, 2, 0)
|
||||||
|
value = paddle.unsqueeze(value, axis=2)
|
||||||
|
res = self.conv3(value)
|
||||||
|
res = paddle.squeeze(res, axis=2)
|
||||||
|
res = res.transpose([2, 0, 1])
|
||||||
|
return res
|
|
@ -0,0 +1,28 @@
|
||||||
|
from paddle import nn
|
||||||
|
|
||||||
|
class MTB(nn.Layer):
|
||||||
|
def __init__(self, cnn_num, in_channels):
|
||||||
|
super(MTB, self).__init__()
|
||||||
|
self.block = nn.Sequential()
|
||||||
|
self.out_channels = in_channels
|
||||||
|
self.cnn_num = cnn_num
|
||||||
|
if self.cnn_num == 2:
|
||||||
|
for i in range(self.cnn_num):
|
||||||
|
self.block.add_sublayer('conv_{}'.format(i), nn.Conv2D(
|
||||||
|
in_channels = in_channels if i == 0 else 32*(2**(i-1)),
|
||||||
|
out_channels = 32*(2**i),
|
||||||
|
kernel_size = 3,
|
||||||
|
stride = 2,
|
||||||
|
padding=1))
|
||||||
|
self.block.add_sublayer('relu_{}'.format(i), nn.ReLU())
|
||||||
|
self.block.add_sublayer('bn_{}'.format(i), nn.BatchNorm2D(32*(2**i)))
|
||||||
|
|
||||||
|
def forward(self, images):
|
||||||
|
|
||||||
|
x = self.block(images)
|
||||||
|
if self.cnn_num == 2:
|
||||||
|
# (b, w, h, c)
|
||||||
|
x = x.transpose([0, 3, 2, 1])
|
||||||
|
x_shape = x.shape
|
||||||
|
x = x.reshape([x_shape[0], x_shape[1], x_shape[2] * x_shape[3]])
|
||||||
|
return x
|
|
@ -7,7 +7,11 @@ from paddle.nn import LayerList
|
||||||
from paddle.nn.initializer import XavierNormal as xavier_uniform_
|
from paddle.nn.initializer import XavierNormal as xavier_uniform_
|
||||||
from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
|
from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
<<<<<<< HEAD
|
||||||
from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim
|
from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim
|
||||||
|
=======
|
||||||
|
from ppocr.modeling.backbones.multiheadAttention import MultiheadAttentionOptim
|
||||||
|
>>>>>>> 9c67a7f... add rec_nrtr
|
||||||
from paddle.nn.initializer import Constant as constant_
|
from paddle.nn.initializer import Constant as constant_
|
||||||
from paddle.nn.initializer import XavierNormal as xavier_normal_
|
from paddle.nn.initializer import XavierNormal as xavier_normal_
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ def build_neck(config):
|
||||||
from .sast_fpn import SASTFPN
|
from .sast_fpn import SASTFPN
|
||||||
from .rnn import SequenceEncoder
|
from .rnn import SequenceEncoder
|
||||||
from .pg_fpn import PGFPN
|
from .pg_fpn import PGFPN
|
||||||
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
|
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN','TFEncoder']
|
||||||
|
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
assert module_name in support_dict, Exception('neck only support {}'.format(
|
assert module_name in support_dict, Exception('neck only support {}'.format(
|
||||||
|
|
|
@ -24,16 +24,15 @@ __all__ = ['build_post_process']
|
||||||
from .db_postprocess import DBPostProcess
|
from .db_postprocess import DBPostProcess
|
||||||
from .east_postprocess import EASTPostProcess
|
from .east_postprocess import EASTPostProcess
|
||||||
from .sast_postprocess import SASTPostProcess
|
from .sast_postprocess import SASTPostProcess
|
||||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode
|
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, NRTRLabelDecode
|
||||||
from .cls_postprocess import ClsPostProcess
|
from .cls_postprocess import ClsPostProcess
|
||||||
from .pg_postprocess import PGPostProcess
|
from .pg_postprocess import PGPostProcess
|
||||||
|
|
||||||
|
|
||||||
def build_post_process(config, global_config=None):
|
def build_post_process(config, global_config=None):
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||||
'DistillationCTCLabelDecode'
|
'DistillationCTCLabelDecode', 'NRTRLabelDecode'
|
||||||
]
|
]
|
||||||
|
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
|
|
|
@ -28,7 +28,7 @@ class BaseRecLabelDecode(object):
|
||||||
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
|
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
|
||||||
'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
|
'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
|
||||||
'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
|
'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
|
||||||
'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari'
|
'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari','dict_99'
|
||||||
]
|
]
|
||||||
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
||||||
support_character_type, character_type)
|
support_character_type, character_type)
|
||||||
|
@ -256,8 +256,7 @@ class AttnLabelDecode(BaseRecLabelDecode):
|
||||||
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||||
batch_idx][idx]:
|
batch_idx][idx]:
|
||||||
continue
|
continue
|
||||||
char_list.append(self.character[int(text_index[batch_idx][
|
char_list.append(self.character[int(text_index[batch_idx][idx])])
|
||||||
idx])])
|
|
||||||
if text_prob is not None:
|
if text_prob is not None:
|
||||||
conf_list.append(text_prob[batch_idx][idx])
|
conf_list.append(text_prob[batch_idx][idx])
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -0,0 +1,95 @@
|
||||||
|
!
|
||||||
|
"
|
||||||
|
#
|
||||||
|
$
|
||||||
|
%
|
||||||
|
&
|
||||||
|
'
|
||||||
|
(
|
||||||
|
)
|
||||||
|
*
|
||||||
|
+
|
||||||
|
,
|
||||||
|
-
|
||||||
|
.
|
||||||
|
/
|
||||||
|
0
|
||||||
|
1
|
||||||
|
2
|
||||||
|
3
|
||||||
|
4
|
||||||
|
5
|
||||||
|
6
|
||||||
|
7
|
||||||
|
8
|
||||||
|
9
|
||||||
|
:
|
||||||
|
;
|
||||||
|
<
|
||||||
|
=
|
||||||
|
>
|
||||||
|
?
|
||||||
|
@
|
||||||
|
A
|
||||||
|
B
|
||||||
|
C
|
||||||
|
D
|
||||||
|
E
|
||||||
|
F
|
||||||
|
G
|
||||||
|
H
|
||||||
|
I
|
||||||
|
J
|
||||||
|
K
|
||||||
|
L
|
||||||
|
M
|
||||||
|
N
|
||||||
|
O
|
||||||
|
P
|
||||||
|
Q
|
||||||
|
R
|
||||||
|
S
|
||||||
|
T
|
||||||
|
U
|
||||||
|
V
|
||||||
|
W
|
||||||
|
X
|
||||||
|
Y
|
||||||
|
Z
|
||||||
|
[
|
||||||
|
\
|
||||||
|
]
|
||||||
|
^
|
||||||
|
_
|
||||||
|
`
|
||||||
|
a
|
||||||
|
b
|
||||||
|
c
|
||||||
|
d
|
||||||
|
e
|
||||||
|
f
|
||||||
|
g
|
||||||
|
h
|
||||||
|
i
|
||||||
|
j
|
||||||
|
k
|
||||||
|
l
|
||||||
|
m
|
||||||
|
n
|
||||||
|
o
|
||||||
|
p
|
||||||
|
q
|
||||||
|
r
|
||||||
|
s
|
||||||
|
t
|
||||||
|
u
|
||||||
|
v
|
||||||
|
w
|
||||||
|
x
|
||||||
|
y
|
||||||
|
z
|
||||||
|
{
|
||||||
|
|
|
||||||
|
}
|
||||||
|
~
|
||||||
|
|
|
@ -22,6 +22,7 @@ import sys
|
||||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||||
sys.path.append(__dir__)
|
sys.path.append(__dir__)
|
||||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||||
|
|
||||||
from ppocr.data import build_dataloader
|
from ppocr.data import build_dataloader
|
||||||
from ppocr.modeling.architectures import build_model
|
from ppocr.modeling.architectures import build_model
|
||||||
from ppocr.postprocess import build_post_process
|
from ppocr.postprocess import build_post_process
|
||||||
|
@ -30,6 +31,7 @@ from ppocr.utils.save_load import init_model
|
||||||
from ppocr.utils.utility import print_dict
|
from ppocr.utils.utility import print_dict
|
||||||
import tools.program as program
|
import tools.program as program
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
global_config = config['Global']
|
global_config = config['Global']
|
||||||
# build dataloader
|
# build dataloader
|
||||||
|
|
|
@ -186,7 +186,7 @@ def train(config,
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||||
|
use_nrtr = config['Architecture']['algorithm'] == "NRTR"
|
||||||
if 'start_epoch' in best_model_dict:
|
if 'start_epoch' in best_model_dict:
|
||||||
start_epoch = best_model_dict['start_epoch']
|
start_epoch = best_model_dict['start_epoch']
|
||||||
else:
|
else:
|
||||||
|
@ -211,6 +211,9 @@ def train(config,
|
||||||
others = batch[-4:]
|
others = batch[-4:]
|
||||||
preds = model(images, others)
|
preds = model(images, others)
|
||||||
model_average = True
|
model_average = True
|
||||||
|
elif use_nrtr:
|
||||||
|
max_len = batch[2].max()
|
||||||
|
preds = model(images, batch[1][:,:2+max_len])
|
||||||
else:
|
else:
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
loss = loss_class(preds, batch)
|
loss = loss_class(preds, batch)
|
||||||
|
@ -350,13 +353,11 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
|
||||||
break
|
break
|
||||||
images = batch[0]
|
images = batch[0]
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
if use_srn:
|
if use_srn:
|
||||||
others = batch[-4:]
|
others = batch[-4:]
|
||||||
preds = model(images, others)
|
preds = model(images, others)
|
||||||
else:
|
else:
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
|
|
||||||
batch = [item.numpy() for item in batch]
|
batch = [item.numpy() for item in batch]
|
||||||
# Obtain usable results from post-processing methods
|
# Obtain usable results from post-processing methods
|
||||||
post_result = post_process_class(preds, batch[1])
|
post_result = post_process_class(preds, batch[1])
|
||||||
|
@ -386,7 +387,7 @@ def preprocess(is_train=False):
|
||||||
alg = config['Architecture']['algorithm']
|
alg = config['Architecture']['algorithm']
|
||||||
assert alg in [
|
assert alg in [
|
||||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||||
'CLS', 'PGNet', 'Distillation'
|
'CLS', 'PGNet', 'Distillation','NRTR'
|
||||||
]
|
]
|
||||||
|
|
||||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||||
|
|
Loading…
Reference in New Issue