add_rec_sar, test=dygraph
This commit is contained in:
parent
ffa94415c3
commit
8a95b3352d
doc
ppocr
data/imaug
losses
modeling
postprocess
tools
|
@ -44,6 +44,7 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|
|||
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
|
||||
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
|
||||
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
|
||||
- [x] SAR([paper](https://arxiv.org/abs/1811.00751v2))
|
||||
|
||||
参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下:
|
||||
|
||||
|
@ -58,6 +59,6 @@ PaddleOCR基于动态图开源的文本识别算法列表:
|
|||
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|
||||
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
||||
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) |
|
||||
|
||||
|SAR|Resnet31| 87.1% | rec_r31_sar | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|
||||
|
||||
PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。
|
||||
|
|
|
@ -215,6 +215,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
|
|||
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
|
||||
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
||||
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
||||
| rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder |
|
||||
|
||||
训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件:
|
||||
|
||||
|
|
|
@ -46,6 +46,7 @@ PaddleOCR open-source text recognition algorithms list:
|
|||
- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
|
||||
- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
|
||||
- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
|
||||
- [x] SAR([paper](https://arxiv.org/abs/1811.00751v2))
|
||||
|
||||
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
|
||||
|
||||
|
@ -60,5 +61,6 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|
|||
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|
||||
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|
||||
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
|
||||
|SAR|Resnet31| 87.1% | rec_r31_sar | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
|
||||
|
||||
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
|
||||
|
|
|
@ -207,6 +207,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
|
|||
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
|
||||
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
|
||||
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
|
||||
| rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder |
|
||||
|
||||
|
||||
For training Chinese data, it is recommended to use
|
||||
|
|
|
@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
|
|||
from .make_shrink_map import MakeShrinkMap
|
||||
from .random_crop_data import EastRandomCropData, PSERandomCrop
|
||||
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, SARRecResizeImg
|
||||
from .randaugment import RandAugment
|
||||
from .copy_paste import CopyPaste
|
||||
from .operators import *
|
||||
|
|
|
@ -521,3 +521,49 @@ class TableLabelEncode(object):
|
|||
assert False, "Unsupport type %s in char_or_elem" \
|
||||
% char_or_elem
|
||||
return idx
|
||||
|
||||
|
||||
class SARLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
character_dict_path=None,
|
||||
character_type='ch',
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
super(SARLabelEncode,
|
||||
self).__init__(max_text_length, character_dict_path,
|
||||
character_type, use_space_char)
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
beg_end_str = "<BOS/EOS>"
|
||||
unknown_str = "<UKN>"
|
||||
padding_str = "<PAD>"
|
||||
dict_character = dict_character + [unknown_str]
|
||||
self.unknown_idx = len(dict_character) - 1
|
||||
dict_character = dict_character + [beg_end_str]
|
||||
self.start_idx = len(dict_character) - 1
|
||||
self.end_idx = len(dict_character) - 1
|
||||
dict_character = dict_character + [padding_str]
|
||||
self.padding_idx = len(dict_character) - 1
|
||||
|
||||
return dict_character
|
||||
|
||||
def __call__(self, data):
|
||||
text = data['label']
|
||||
text = self.encode(text)
|
||||
if text is None:
|
||||
return None
|
||||
if len(text) >= self.max_text_len - 1:
|
||||
return None
|
||||
data['length'] = np.array(len(text))
|
||||
target = [self.start_idx] + text + [self.end_idx]
|
||||
padded_text = [self.padding_idx for _ in range(self.max_text_len)]
|
||||
|
||||
padded_text[:len(target)] = target
|
||||
data['label'] = np.array(padded_text)
|
||||
return data
|
||||
|
||||
def get_ignored_tokens(self):
|
||||
return [self.padding_idx]
|
||||
|
|
|
@ -83,6 +83,56 @@ class SRNRecResizeImg(object):
|
|||
return data
|
||||
|
||||
|
||||
class SARRecResizeImg(object):
|
||||
def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
self.width_downsample_ratio = width_downsample_ratio
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(img, self.image_shape, self.width_downsample_ratio)
|
||||
data['image'] = norm_img
|
||||
data['resized_shape'] = resize_shape
|
||||
data['pad_shape'] = pad_shape
|
||||
data['valid_ratio'] = valid_ratio
|
||||
return data
|
||||
|
||||
|
||||
def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
|
||||
imgC, imgH, imgW_min, imgW_max = image_shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
valid_ratio = 1.0
|
||||
# make sure new_width is an integral multiple of width_divisor.
|
||||
width_divisor = int(1 / width_downsample_ratio)
|
||||
# resize
|
||||
ratio = w / float(h)
|
||||
resize_w = math.ceil(imgH * ratio)
|
||||
if resize_w % width_divisor != 0:
|
||||
resize_w = round(resize_w / width_divisor) * width_divisor
|
||||
if imgW_min is not None:
|
||||
resize_w = max(imgW_min, resize_w)
|
||||
if imgW_max is not None:
|
||||
valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
|
||||
resize_w = min(imgW_max, resize_w)
|
||||
resized_image = cv2.resize(img, (resize_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
# norm
|
||||
if image_shape[0] == 1:
|
||||
resized_image = resized_image / 255
|
||||
resized_image = resized_image[np.newaxis, :]
|
||||
else:
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
resize_shape = resized_image.shape
|
||||
padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
|
||||
padding_im[:, :, 0:resize_w] = resized_image
|
||||
pad_shape = padding_im.shape
|
||||
|
||||
return padding_im, resize_shape, pad_shape, valid_ratio
|
||||
|
||||
|
||||
def resize_norm_img(img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
h = img.shape[0]
|
||||
|
|
|
@ -25,6 +25,7 @@ from .det_sast_loss import SASTLoss
|
|||
from .rec_ctc_loss import CTCLoss
|
||||
from .rec_att_loss import AttentionLoss
|
||||
from .rec_srn_loss import SRNLoss
|
||||
from .rec_sar_loss import SARLoss
|
||||
|
||||
# cls loss
|
||||
from .cls_loss import ClsLoss
|
||||
|
@ -44,7 +45,7 @@ from .table_att_loss import TableAttentionLoss
|
|||
def build_loss(config):
|
||||
support_dict = [
|
||||
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||
'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
|
||||
'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss', 'SARLoss'
|
||||
]
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -26,8 +26,9 @@ def build_backbone(config, model_type):
|
|||
from .rec_resnet_vd import ResNet
|
||||
from .rec_resnet_fpn import ResNetFPN
|
||||
from .rec_mv1_enhance import MobileNetV1Enhance
|
||||
from .rec_resnet_31 import ResNet31
|
||||
support_dict = [
|
||||
"MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN"
|
||||
"MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN", "ResNet31"
|
||||
]
|
||||
elif model_type == "e2e":
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
|
|
|
@ -26,12 +26,13 @@ def build_head(config):
|
|||
from .rec_ctc_head import CTCHead
|
||||
from .rec_att_head import AttentionHead
|
||||
from .rec_srn_head import SRNHead
|
||||
from .rec_sar_head import SARHead
|
||||
|
||||
# cls head
|
||||
from .cls_head import ClsHead
|
||||
support_dict = [
|
||||
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
||||
'SRNHead', 'PGHead', 'TableAttentionHead']
|
||||
'SRNHead', 'PGHead', 'TableAttentionHead', 'SARHead']
|
||||
|
||||
#table head
|
||||
from .table_att_head import TableAttentionHead
|
||||
|
|
|
@ -25,7 +25,7 @@ from .db_postprocess import DBPostProcess, DistillationDBPostProcess
|
|||
from .east_postprocess import EASTPostProcess
|
||||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
|
||||
TableLabelDecode
|
||||
TableLabelDecode, SARLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
|
||||
|
@ -35,7 +35,7 @@ def build_post_process(config, global_config=None):
|
|||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
|
||||
'DistillationCTCLabelDecode', 'TableLabelDecode',
|
||||
'DistillationDBPostProcess'
|
||||
'DistillationDBPostProcess', 'SARLabelDecode'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -15,6 +15,7 @@ import numpy as np
|
|||
import string
|
||||
import paddle
|
||||
from paddle.nn import functional as F
|
||||
import re
|
||||
|
||||
|
||||
class BaseRecLabelDecode(object):
|
||||
|
@ -454,3 +455,79 @@ class TableLabelDecode(object):
|
|||
assert False, "Unsupport type %s in char_or_elem" \
|
||||
% char_or_elem
|
||||
return idx
|
||||
|
||||
|
||||
class SARLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self,
|
||||
character_dict_path=None,
|
||||
character_type='ch',
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
super(SARLabelDecode, self).__init__(character_dict_path,
|
||||
character_type, use_space_char)
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
beg_end_str = "<BOS/EOS>"
|
||||
unknown_str = "<UKN>"
|
||||
padding_str = "<PAD>"
|
||||
dict_character = dict_character + [unknown_str]
|
||||
self.unknown_idx = len(dict_character) - 1
|
||||
dict_character = dict_character + [beg_end_str]
|
||||
self.start_idx = len(dict_character) - 1
|
||||
self.end_idx = len(dict_character) - 1
|
||||
dict_character = dict_character + [padding_str]
|
||||
self.padding_idx = len(dict_character) - 1
|
||||
return dict_character
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
result_list = []
|
||||
ignored_tokens = self.get_ignored_tokens()
|
||||
|
||||
batch_size = len(text_index)
|
||||
for batch_idx in range(batch_size):
|
||||
char_list = []
|
||||
conf_list = []
|
||||
for idx in range(len(text_index[batch_idx])):
|
||||
if text_index[batch_idx][idx] in ignored_tokens:
|
||||
continue
|
||||
if int(text_index[batch_idx][idx]) == int(self.end_idx):
|
||||
if text_prob is None and idx ==0:
|
||||
continue
|
||||
else:
|
||||
break
|
||||
if is_remove_duplicate:
|
||||
# only for predict
|
||||
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||
batch_idx][idx]:
|
||||
continue
|
||||
char_list.append(self.character[int(text_index[batch_idx][
|
||||
idx])])
|
||||
if text_prob is not None:
|
||||
conf_list.append(text_prob[batch_idx][idx])
|
||||
else:
|
||||
conf_list.append(1)
|
||||
text = ''.join(char_list)
|
||||
comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
|
||||
text = text.lower()
|
||||
text = comp.sub('', text)
|
||||
result_list.append((text, np.mean(conf_list)))
|
||||
return result_list
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
|
||||
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label, is_remove_duplicate=False)
|
||||
return text, label
|
||||
|
||||
def get_ignored_tokens(self):
|
||||
return [self.padding_idx]
|
||||
|
|
|
@ -55,6 +55,7 @@ def main():
|
|||
|
||||
model = build_model(config['Architecture'])
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
use_sar = config['Architecture']['algorithm'] == "SAR"
|
||||
if "model_type" in config['Architecture'].keys():
|
||||
model_type = config['Architecture']['model_type']
|
||||
else:
|
||||
|
@ -71,7 +72,7 @@ def main():
|
|||
|
||||
# start eval
|
||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class, model_type, use_srn)
|
||||
eval_class, model_type, use_srn, use_sar)
|
||||
logger.info('metric eval ***************')
|
||||
for k, v in metric.items():
|
||||
logger.info('{}:{}'.format(k, v))
|
||||
|
|
|
@ -74,6 +74,10 @@ def main():
|
|||
'image', 'encoder_word_pos', 'gsrm_word_pos',
|
||||
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
|
||||
]
|
||||
elif config['Architecture']['algorithm'] == "SAR":
|
||||
op[op_name]['keep_keys'] = [
|
||||
'image', 'valid_ratio'
|
||||
]
|
||||
else:
|
||||
op[op_name]['keep_keys'] = ['image']
|
||||
transforms.append(op)
|
||||
|
@ -106,11 +110,16 @@ def main():
|
|||
paddle.to_tensor(gsrm_slf_attn_bias1_list),
|
||||
paddle.to_tensor(gsrm_slf_attn_bias2_list)
|
||||
]
|
||||
if config['Architecture']['algorithm'] == "SAR":
|
||||
valid_ratio = np.expand_dims(batch[-1], axis=0)
|
||||
img_metas = [paddle.to_tensor(valid_ratio)]
|
||||
|
||||
images = np.expand_dims(batch[0], axis=0)
|
||||
images = paddle.to_tensor(images)
|
||||
if config['Architecture']['algorithm'] == "SRN":
|
||||
preds = model(images, others)
|
||||
elif config['Architecture']['algorithm'] == "SAR":
|
||||
preds = model(images, img_metas)
|
||||
else:
|
||||
preds = model(images)
|
||||
post_result = post_process_class(preds)
|
||||
|
|
|
@ -186,6 +186,7 @@ def train(config,
|
|||
model.train()
|
||||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
use_sar = config['Architecture']['algorithm'] == 'SAR'
|
||||
try:
|
||||
model_type = config['Architecture']['model_type']
|
||||
except:
|
||||
|
@ -213,7 +214,7 @@ def train(config,
|
|||
images = batch[0]
|
||||
if use_srn:
|
||||
model_average = True
|
||||
if use_srn or model_type == 'table':
|
||||
if use_srn or model_type == 'table' or use_sar:
|
||||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
|
@ -277,7 +278,8 @@ def train(config,
|
|||
post_process_class,
|
||||
eval_class,
|
||||
model_type,
|
||||
use_srn=use_srn)
|
||||
use_srn=use_srn,
|
||||
use_sar=use_sar)
|
||||
cur_metric_str = 'cur metric, {}'.format(', '.join(
|
||||
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
|
||||
logger.info(cur_metric_str)
|
||||
|
@ -349,7 +351,8 @@ def eval(model,
|
|||
post_process_class,
|
||||
eval_class,
|
||||
model_type,
|
||||
use_srn=False):
|
||||
use_srn=False,
|
||||
use_sar=False):
|
||||
model.eval()
|
||||
with paddle.no_grad():
|
||||
total_frame = 0.0
|
||||
|
@ -362,7 +365,7 @@ def eval(model,
|
|||
break
|
||||
images = batch[0]
|
||||
start = time.time()
|
||||
if use_srn or model_type == 'table':
|
||||
if use_srn or model_type == 'table' or use_sar:
|
||||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
|
@ -398,7 +401,7 @@ def preprocess(is_train=False):
|
|||
alg = config['Architecture']['algorithm']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'TableAttn'
|
||||
'CLS', 'PGNet', 'Distillation', 'TableAttn', 'SAR'
|
||||
]
|
||||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
|
|
Loading…
Reference in New Issue