add rec_nrtr

This commit is contained in:
Topdu 2021-08-17 13:37:32 +00:00
parent 8762a64f2d
commit 4cb824537a
8 changed files with 6 additions and 122 deletions

View File

@ -3,38 +3,22 @@ Global:
epoch_num: 21
log_smooth_window: 20
print_batch_step: 10
<<<<<<< HEAD
save_model_dir: ./output/rec/nrtr_final/
save_model_dir: ./output/rec/nrtr/
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: True
=======
save_model_dir: ./output/rec/piloptimnrtr/
save_epoch_step: 1
# evaluation is run every 2000 iterations
eval_batch_step: [0, 2000]
cal_metric_during_train: False
>>>>>>> 9c67a7f... add rec_nrtr
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
<<<<<<< HEAD
character_dict_path:
character_type: EN_symbol
max_text_length: 25
infer_mode: False
use_space_char: True
=======
character_dict_path: ppocr/utils/dict_99.txt
character_type: dict_99
max_text_length: 25
infer_mode: False
use_space_char: False
>>>>>>> 9c67a7f... add rec_nrtr
save_res_path: ./output/rec/predicts_nrtr.txt
Optimizer:

View File

@ -96,7 +96,7 @@ class BaseRecLabelEncode(object):
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari','dict_99'
'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari'
]
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type)

View File

@ -30,7 +30,7 @@ class RecMetric(object):
target = target.replace(" ", "")
norm_edit_dis += Levenshtein.distance(pred, target) / max(
len(pred), len(target), 1)
if pred.lower() == target.lower():
if pred == target:
correct_num += 1
all_num += 1
self.correct_num += correct_num
@ -57,4 +57,4 @@ class RecMetric(object):
self.correct_num = 0
self.all_num = 0
self.norm_edit_dis = 0

View File

@ -14,7 +14,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone

View File

@ -27,7 +27,7 @@ def build_backbone(config, model_type):
from .rec_resnet_fpn import ResNetFPN
from .rec_nrtr_mtb import MTB
from .rec_swin import SwinTransformer
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN','MTB','SwinTransformer']
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'SwinTransformer']
elif model_type == 'e2e':
from .e2e_resnet_vd_pg import ResNet

View File

@ -7,11 +7,7 @@ from paddle.nn import LayerList
from paddle.nn.initializer import XavierNormal as xavier_uniform_
from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
import numpy as np
<<<<<<< HEAD
from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim
=======
from ppocr.modeling.backbones.multiheadAttention import MultiheadAttentionOptim
>>>>>>> 9c67a7f... add rec_nrtr
from paddle.nn.initializer import Constant as constant_
from paddle.nn.initializer import XavierNormal as xavier_normal_

View File

@ -28,7 +28,7 @@ class BaseRecLabelDecode(object):
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari','dict_99'
'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari'
]
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type)

View File

@ -1,95 +0,0 @@
!
"
#
$
%
&
'
(
)
*
+
,
-
.
/
0
1
2
3
4
5
6
7
8
9
:
;
<
=
>
?
@
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
[
\
]
^
_
`
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
q
r
s
t
u
v
w
x
y
z
{
|
}
~