add rec_nrtr
This commit is contained in:
parent
8762a64f2d
commit
4cb824537a
|
@ -3,38 +3,22 @@ Global:
|
|||
epoch_num: 21
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
<<<<<<< HEAD
|
||||
save_model_dir: ./output/rec/nrtr_final/
|
||||
save_model_dir: ./output/rec/nrtr/
|
||||
save_epoch_step: 1
|
||||
# evaluation is run every 2000 iterations
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: True
|
||||
=======
|
||||
save_model_dir: ./output/rec/piloptimnrtr/
|
||||
save_epoch_step: 1
|
||||
# evaluation is run every 2000 iterations
|
||||
eval_batch_step: [0, 2000]
|
||||
cal_metric_during_train: False
|
||||
>>>>>>> 9c67a7f... add rec_nrtr
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
<<<<<<< HEAD
|
||||
character_dict_path:
|
||||
character_type: EN_symbol
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: True
|
||||
=======
|
||||
character_dict_path: ppocr/utils/dict_99.txt
|
||||
character_type: dict_99
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
>>>>>>> 9c67a7f... add rec_nrtr
|
||||
save_res_path: ./output/rec/predicts_nrtr.txt
|
||||
|
||||
Optimizer:
|
||||
|
|
|
@ -96,7 +96,7 @@ class BaseRecLabelEncode(object):
|
|||
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
|
||||
'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
|
||||
'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
|
||||
'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari','dict_99'
|
||||
'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari'
|
||||
]
|
||||
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
||||
support_character_type, character_type)
|
||||
|
|
|
@ -30,7 +30,7 @@ class RecMetric(object):
|
|||
target = target.replace(" ", "")
|
||||
norm_edit_dis += Levenshtein.distance(pred, target) / max(
|
||||
len(pred), len(target), 1)
|
||||
if pred.lower() == target.lower():
|
||||
if pred == target:
|
||||
correct_num += 1
|
||||
all_num += 1
|
||||
self.correct_num += correct_num
|
||||
|
@ -57,4 +57,4 @@ class RecMetric(object):
|
|||
self.correct_num = 0
|
||||
self.all_num = 0
|
||||
self.norm_edit_dis = 0
|
||||
|
||||
|
||||
|
|
|
@ -14,7 +14,6 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from ppocr.modeling.transforms import build_transform
|
||||
from ppocr.modeling.backbones import build_backbone
|
||||
|
|
|
@ -27,7 +27,7 @@ def build_backbone(config, model_type):
|
|||
from .rec_resnet_fpn import ResNetFPN
|
||||
from .rec_nrtr_mtb import MTB
|
||||
from .rec_swin import SwinTransformer
|
||||
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN','MTB','SwinTransformer']
|
||||
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'SwinTransformer']
|
||||
|
||||
elif model_type == 'e2e':
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
|
|
|
@ -7,11 +7,7 @@ from paddle.nn import LayerList
|
|||
from paddle.nn.initializer import XavierNormal as xavier_uniform_
|
||||
from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
|
||||
import numpy as np
|
||||
<<<<<<< HEAD
|
||||
from ppocr.modeling.heads.multiheadAttention import MultiheadAttentionOptim
|
||||
=======
|
||||
from ppocr.modeling.backbones.multiheadAttention import MultiheadAttentionOptim
|
||||
>>>>>>> 9c67a7f... add rec_nrtr
|
||||
from paddle.nn.initializer import Constant as constant_
|
||||
from paddle.nn.initializer import XavierNormal as xavier_normal_
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ class BaseRecLabelDecode(object):
|
|||
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
|
||||
'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
|
||||
'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
|
||||
'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari','dict_99'
|
||||
'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari'
|
||||
]
|
||||
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
||||
support_character_type, character_type)
|
||||
|
|
|
@ -1,95 +0,0 @@
|
|||
!
|
||||
"
|
||||
#
|
||||
$
|
||||
%
|
||||
&
|
||||
'
|
||||
(
|
||||
)
|
||||
*
|
||||
+
|
||||
,
|
||||
-
|
||||
.
|
||||
/
|
||||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
:
|
||||
;
|
||||
<
|
||||
=
|
||||
>
|
||||
?
|
||||
@
|
||||
A
|
||||
B
|
||||
C
|
||||
D
|
||||
E
|
||||
F
|
||||
G
|
||||
H
|
||||
I
|
||||
J
|
||||
K
|
||||
L
|
||||
M
|
||||
N
|
||||
O
|
||||
P
|
||||
Q
|
||||
R
|
||||
S
|
||||
T
|
||||
U
|
||||
V
|
||||
W
|
||||
X
|
||||
Y
|
||||
Z
|
||||
[
|
||||
\
|
||||
]
|
||||
^
|
||||
_
|
||||
`
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
h
|
||||
i
|
||||
j
|
||||
k
|
||||
l
|
||||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
t
|
||||
u
|
||||
v
|
||||
w
|
||||
x
|
||||
y
|
||||
z
|
||||
{
|
||||
|
|
||||
}
|
||||
~
|
||||
|
Loading…
Reference in New Issue