use public function for inference

This commit is contained in:
tink2123 2020-05-12 19:55:16 +08:00
parent 4eb359e8a7
commit d24a39df88
8 changed files with 37 additions and 162 deletions

View File

@ -1,42 +0,0 @@
Global:
algorithm: CRNN
use_gpu: true
epoch_num: 300
log_smooth_window: 20
print_batch_step: 10
save_model_dir: output
save_epoch_step: 3
eval_batch_step: 2000
train_batch_size_per_card: 256
test_batch_size_per_card: 256
image_shape: [3, 32, 100]
max_text_length: 25
character_type: ch
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
loss_type: ctc
reader_yml: ./configs/rec/rec_chinese_reader.yml
pretrain_weights:
infer_img: ./infer_img
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel
Backbone:
function: ppocr.modeling.backbones.rec_mobilenet_v3,MobileNetV3
scale: 0.5
model_name: small
Head:
function: ppocr.modeling.heads.rec_ctc_head,CTCPredict
encoder_type: rnn
SeqRNN:
hidden_size: 48
Loss:
function: ppocr.modeling.losses.rec_ctc_loss,CTCLoss
Optimizer:
function: ppocr.optimizer,AdamDecay
base_lr: 0.001
beta1: 0.9
beta2: 0.999

View File

@ -1,13 +0,0 @@
TrainReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
num_workers: 8
img_set_dir: ./train_data
label_file_path: ./train_data/train_label.txt
EvalReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
img_set_dir: ./train_data
label_file_path: ./train_data/test_label.txt
TestReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader

View File

@ -11,3 +11,4 @@ EvalReader:
TestReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
infer_img: ./infer_img

View File

@ -5,8 +5,8 @@ Global:
log_smooth_window: 20
print_batch_step: 10
save_model_dir: output_ic15
save_epoch_step: 3
eval_batch_step: 2000
save_epoch_step: 300
eval_batch_step: 200
train_batch_size_per_card: 256
test_batch_size_per_card: 256
image_shape: [3, 32, 100]
@ -14,9 +14,8 @@ Global:
character_type: ch
character_dict_path: ./ppocr/utils/ic15_dict.txt
loss_type: ctc
reader_yml: ./configs/rec/rec_ic15_reader.yml
pretrain_weights: ./pretrain_models/best_accuracy
infer_img: ./infer_img
reader_yml: ./configs/rec/rec_icdar15_reader.yml
pretrain_weights: ./pretrain_models/CRNN/best_accuracy
Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel

View File

@ -22,6 +22,7 @@ import string
import lmdb
from ppocr.utils.utility import initial_logger
from tools.infer.utility import get_image_file_list
logger = initial_logger()
from .img_tools import process_image, get_img_data
@ -165,15 +166,7 @@ class SimpleReader(object):
def sample_iter_reader():
if self.mode == 'test':
image_file_list = []
if os.path.isfile(self.infer_img):
image_file_list = [self.infer_img]
elif os.path.isdir(self.infer_img):
for single_file in os.listdir(self.infer_img):
if single_file.split('.')[
-1] not in ['bmp', 'jpg', 'jpeg', 'png', 'JPEG', 'JPG', 'PNG']:
continue
image_file_list.append(os.path.join(self.infer_img, single_file))
image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list:
img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2:

View File

@ -1,86 +1,36 @@
J
O
I
N
T
y
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
u
p
q
r
s
e
l
f
1
5
4
9
7
2
8
0
F
m
P
A
B
L
C
K
S
R
E
Y
U
p
d
g
a
t
i
n
h
W
D
u
v
H
V
G
w
M
!
k
c
.
(
)
X
b
-
Q
x
Z
?
@
3
/
%
$
,
'
:
y
z
&
j
0
1
2
3
4
5
6
+
[
]
;
#
q
\
´
É
=
7
8
9

View File

@ -1,10 +0,0 @@
#. /paddle/set_env.sh↩
export CUDA_VISIBLE_DEVICES="0,1,2,3"
export FLAGS_fraction_of_gpu_memory_to_use=1.0
python_bin_dir="/opt/_internal/cpython-3.7.0/bin/"
alias python=$python_bin_dir"python3.7"
alias pip=$python_bin_dir"pip3.7"
alias ipython=$python_bin_dir"ipython3"
export LD_LIBRARY_PATH=/opt/_internal/cpython-3.7.0/lib:$LD_LIBRARY_PATH
export PYTHONPATH=$PYTHONPATH:.
ldconfig

View File

@ -46,7 +46,7 @@ from ppocr.data.reader_main import reader_main
from ppocr.utils.save_load import init_model
from ppocr.utils.character import CharacterOps
from ppocr.utils.utility import create_module
from tools.infer.utility import get_image_file_list
logger = initial_logger()
@ -79,11 +79,8 @@ def main():
init_model(config, eval_prog, exe)
blobs = reader_main(config, 'test')()
infer_img = config['Global']['infer_img']
if os.path.isfile(infer_img):
infer_list = [infer_img]
elif os.path.isdir(infer_img):
infer_list = os.listdir(config['Global']['infer_img'])
infer_img = config['TestReader']['infer_img']
infer_list = get_image_file_list(infer_img)
max_img_num = len(infer_list)
if len(infer_list) == 0:
logger.info("Can not find img in infer_img dir.")