fix infer_rec for attention
This commit is contained in:
commit
b722eb56c8
|
@ -1,6 +1,6 @@
|
||||||
Global:
|
Global:
|
||||||
algorithm: CRNN
|
algorithm: CRNN
|
||||||
use_gpu: true
|
use_gpu: false
|
||||||
epoch_num: 3000
|
epoch_num: 3000
|
||||||
log_smooth_window: 20
|
log_smooth_window: 20
|
||||||
print_batch_step: 10
|
print_batch_step: 10
|
||||||
|
@ -8,6 +8,7 @@ Global:
|
||||||
save_epoch_step: 3
|
save_epoch_step: 3
|
||||||
eval_batch_step: 2000
|
eval_batch_step: 2000
|
||||||
train_batch_size_per_card: 256
|
train_batch_size_per_card: 256
|
||||||
|
drop_last: true
|
||||||
test_batch_size_per_card: 256
|
test_batch_size_per_card: 256
|
||||||
image_shape: [3, 32, 320]
|
image_shape: [3, 32, 320]
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
|
@ -15,7 +16,7 @@ Global:
|
||||||
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
|
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
|
||||||
loss_type: ctc
|
loss_type: ctc
|
||||||
reader_yml: ./configs/rec/rec_chinese_reader.yml
|
reader_yml: ./configs/rec/rec_chinese_reader.yml
|
||||||
pretrain_weights:
|
pretrain_weights: output/rec_CRNN/rec_mv3_crnn/best_accuracy
|
||||||
checkpoints:
|
checkpoints:
|
||||||
save_inference_dir:
|
save_inference_dir:
|
||||||
infer_img:
|
infer_img:
|
||||||
|
|
|
@ -8,13 +8,14 @@ Global:
|
||||||
save_epoch_step: 300
|
save_epoch_step: 300
|
||||||
eval_batch_step: 500
|
eval_batch_step: 500
|
||||||
train_batch_size_per_card: 256
|
train_batch_size_per_card: 256
|
||||||
|
drop_last: true
|
||||||
test_batch_size_per_card: 256
|
test_batch_size_per_card: 256
|
||||||
image_shape: [3, 32, 100]
|
image_shape: [3, 32, 100]
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
character_type: en
|
character_type: en
|
||||||
loss_type: ctc
|
loss_type: ctc
|
||||||
reader_yml: ./configs/rec/rec_icdar15_reader.yml
|
reader_yml: ./configs/rec/rec_icdar15_reader.yml
|
||||||
pretrain_weights: ./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy
|
pretrain_weights:
|
||||||
checkpoints:
|
checkpoints:
|
||||||
save_inference_dir:
|
save_inference_dir:
|
||||||
infer_img:
|
infer_img:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
Global:
|
Global:
|
||||||
algorithm: CRNN
|
algorithm: CRNN
|
||||||
use_gpu: true
|
use_gpu: false
|
||||||
epoch_num: 72
|
epoch_num: 72
|
||||||
log_smooth_window: 20
|
log_smooth_window: 20
|
||||||
print_batch_step: 10
|
print_batch_step: 10
|
||||||
|
@ -8,13 +8,14 @@ Global:
|
||||||
save_epoch_step: 3
|
save_epoch_step: 3
|
||||||
eval_batch_step: 2000
|
eval_batch_step: 2000
|
||||||
train_batch_size_per_card: 256
|
train_batch_size_per_card: 256
|
||||||
|
drop_last: true
|
||||||
test_batch_size_per_card: 256
|
test_batch_size_per_card: 256
|
||||||
image_shape: [3, 32, 100]
|
image_shape: [3, 32, 100]
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
character_type: en
|
character_type: en
|
||||||
loss_type: ctc
|
loss_type: ctc
|
||||||
reader_yml: ./configs/rec/rec_benchmark_reader.yml
|
reader_yml: ./configs/rec/rec_benchmark_reader.yml
|
||||||
pretrain_weights: ./output/rec_CRNN/rec_mv3_none_bilstm_ctc/best_accuracy
|
pretrain_weights:
|
||||||
checkpoints:
|
checkpoints:
|
||||||
save_inference_dir:
|
save_inference_dir:
|
||||||
infer_img:
|
infer_img:
|
||||||
|
|
|
@ -8,6 +8,7 @@ Global:
|
||||||
save_epoch_step: 3
|
save_epoch_step: 3
|
||||||
eval_batch_step: 2000
|
eval_batch_step: 2000
|
||||||
train_batch_size_per_card: 256
|
train_batch_size_per_card: 256
|
||||||
|
drop_last: true
|
||||||
test_batch_size_per_card: 256
|
test_batch_size_per_card: 256
|
||||||
image_shape: [3, 32, 100]
|
image_shape: [3, 32, 100]
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
Global:
|
Global:
|
||||||
algorithm: RARE
|
algorithm: RARE
|
||||||
use_gpu: true
|
use_gpu: false
|
||||||
epoch_num: 72
|
epoch_num: 72
|
||||||
log_smooth_window: 20
|
log_smooth_window: 20
|
||||||
print_batch_step: 10
|
print_batch_step: 10
|
||||||
|
@ -8,6 +8,7 @@ Global:
|
||||||
save_epoch_step: 3
|
save_epoch_step: 3
|
||||||
eval_batch_step: 2000
|
eval_batch_step: 2000
|
||||||
train_batch_size_per_card: 256
|
train_batch_size_per_card: 256
|
||||||
|
drop_last: true
|
||||||
test_batch_size_per_card: 256
|
test_batch_size_per_card: 256
|
||||||
image_shape: [3, 32, 100]
|
image_shape: [3, 32, 100]
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
|
|
|
@ -8,6 +8,7 @@ Global:
|
||||||
save_epoch_step: 3
|
save_epoch_step: 3
|
||||||
eval_batch_step: 2000
|
eval_batch_step: 2000
|
||||||
train_batch_size_per_card: 256
|
train_batch_size_per_card: 256
|
||||||
|
drop_last: true
|
||||||
test_batch_size_per_card: 256
|
test_batch_size_per_card: 256
|
||||||
image_shape: [3, 32, 100]
|
image_shape: [3, 32, 100]
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
|
|
|
@ -8,6 +8,7 @@ Global:
|
||||||
save_epoch_step: 3
|
save_epoch_step: 3
|
||||||
eval_batch_step: 2000
|
eval_batch_step: 2000
|
||||||
train_batch_size_per_card: 256
|
train_batch_size_per_card: 256
|
||||||
|
drop_last: true
|
||||||
test_batch_size_per_card: 256
|
test_batch_size_per_card: 256
|
||||||
image_shape: [3, 32, 100]
|
image_shape: [3, 32, 100]
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
|
|
|
@ -8,6 +8,7 @@ Global:
|
||||||
save_epoch_step: 3
|
save_epoch_step: 3
|
||||||
eval_batch_step: 2000
|
eval_batch_step: 2000
|
||||||
train_batch_size_per_card: 256
|
train_batch_size_per_card: 256
|
||||||
|
drop_last: true
|
||||||
test_batch_size_per_card: 256
|
test_batch_size_per_card: 256
|
||||||
image_shape: [3, 32, 100]
|
image_shape: [3, 32, 100]
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
|
|
|
@ -8,6 +8,7 @@ Global:
|
||||||
save_epoch_step: 3
|
save_epoch_step: 3
|
||||||
eval_batch_step: 2000
|
eval_batch_step: 2000
|
||||||
train_batch_size_per_card: 256
|
train_batch_size_per_card: 256
|
||||||
|
drop_last: true
|
||||||
test_batch_size_per_card: 256
|
test_batch_size_per_card: 256
|
||||||
image_shape: [3, 32, 100]
|
image_shape: [3, 32, 100]
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
|
|
|
@ -8,6 +8,7 @@ Global:
|
||||||
save_epoch_step: 3
|
save_epoch_step: 3
|
||||||
eval_batch_step: 2000
|
eval_batch_step: 2000
|
||||||
train_batch_size_per_card: 256
|
train_batch_size_per_card: 256
|
||||||
|
drop_last: true
|
||||||
test_batch_size_per_card: 256
|
test_batch_size_per_card: 256
|
||||||
image_shape: [3, 32, 100]
|
image_shape: [3, 32, 100]
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
|
|
|
@ -17,6 +17,8 @@ import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
|
from ppocr.utils.utility import initial_logger
|
||||||
|
logger = initial_logger()
|
||||||
|
|
||||||
from .data_augment import AugmentData
|
from .data_augment import AugmentData
|
||||||
from .random_crop_data import RandomCropData
|
from .random_crop_data import RandomCropData
|
||||||
|
@ -100,6 +102,7 @@ class DBProcessTrain(object):
|
||||||
img_path, gt_label = self.convert_label_infor(label_infor)
|
img_path, gt_label = self.convert_label_infor(label_infor)
|
||||||
imgvalue = cv2.imread(img_path)
|
imgvalue = cv2.imread(img_path)
|
||||||
if imgvalue is None:
|
if imgvalue is None:
|
||||||
|
logger.info("{} does not exist!".format(img_path))
|
||||||
return None
|
return None
|
||||||
data = self.make_data_dict(imgvalue, gt_label)
|
data = self.make_data_dict(imgvalue, gt_label)
|
||||||
data = AugmentData(data)
|
data = AugmentData(data)
|
||||||
|
|
|
@ -43,6 +43,7 @@ class LMDBReader(object):
|
||||||
self.mode = params['mode']
|
self.mode = params['mode']
|
||||||
if params['mode'] == 'train':
|
if params['mode'] == 'train':
|
||||||
self.batch_size = params['train_batch_size_per_card']
|
self.batch_size = params['train_batch_size_per_card']
|
||||||
|
self.drop_last = params['drop_last']
|
||||||
else:
|
else:
|
||||||
self.batch_size = params['test_batch_size_per_card']
|
self.batch_size = params['test_batch_size_per_card']
|
||||||
self.infer_img = params['infer_img']
|
self.infer_img = params['infer_img']
|
||||||
|
@ -99,7 +100,7 @@ class LMDBReader(object):
|
||||||
process_id = 0
|
process_id = 0
|
||||||
|
|
||||||
def sample_iter_reader():
|
def sample_iter_reader():
|
||||||
if self.infer_img is not None:
|
if self.mode != 'train' and self.infer_img is not None:
|
||||||
image_file_list = get_image_file_list(self.infer_img)
|
image_file_list = get_image_file_list(self.infer_img)
|
||||||
for single_img in image_file_list:
|
for single_img in image_file_list:
|
||||||
img = cv2.imread(single_img)
|
img = cv2.imread(single_img)
|
||||||
|
@ -146,10 +147,11 @@ class LMDBReader(object):
|
||||||
if len(batch_outs) == self.batch_size:
|
if len(batch_outs) == self.batch_size:
|
||||||
yield batch_outs
|
yield batch_outs
|
||||||
batch_outs = []
|
batch_outs = []
|
||||||
if len(batch_outs) != 0:
|
if not self.drop_last:
|
||||||
yield batch_outs
|
if len(batch_outs) != 0:
|
||||||
|
yield batch_outs
|
||||||
|
|
||||||
if self.infer_img is None:
|
if self.mode != 'train' and self.infer_img is None:
|
||||||
return batch_iter_reader
|
return batch_iter_reader
|
||||||
return sample_iter_reader
|
return sample_iter_reader
|
||||||
|
|
||||||
|
@ -171,6 +173,7 @@ class SimpleReader(object):
|
||||||
self.infer_img = params['infer_img']
|
self.infer_img = params['infer_img']
|
||||||
if params['mode'] == 'train':
|
if params['mode'] == 'train':
|
||||||
self.batch_size = params['train_batch_size_per_card']
|
self.batch_size = params['train_batch_size_per_card']
|
||||||
|
self.drop_last = params['drop_last']
|
||||||
else:
|
else:
|
||||||
self.batch_size = params['test_batch_size_per_card']
|
self.batch_size = params['test_batch_size_per_card']
|
||||||
|
|
||||||
|
@ -226,8 +229,9 @@ class SimpleReader(object):
|
||||||
if len(batch_outs) == self.batch_size:
|
if len(batch_outs) == self.batch_size:
|
||||||
yield batch_outs
|
yield batch_outs
|
||||||
batch_outs = []
|
batch_outs = []
|
||||||
if len(batch_outs) != 0:
|
if not self.drop_last:
|
||||||
yield batch_outs
|
if len(batch_outs) != 0:
|
||||||
|
yield batch_outs
|
||||||
|
|
||||||
if self.infer_img is None:
|
if self.infer_img is None:
|
||||||
return batch_iter_reader
|
return batch_iter_reader
|
||||||
|
|
|
@ -51,7 +51,7 @@ def resize_norm_img(img, image_shape):
|
||||||
def resize_norm_img_chinese(img, image_shape):
|
def resize_norm_img_chinese(img, image_shape):
|
||||||
imgC, imgH, imgW = image_shape
|
imgC, imgH, imgW = image_shape
|
||||||
# todo: change to 0 and modified image shape
|
# todo: change to 0 and modified image shape
|
||||||
max_wh_ratio = 10
|
max_wh_ratio = 0
|
||||||
h, w = img.shape[0], img.shape[1]
|
h, w = img.shape[0], img.shape[1]
|
||||||
ratio = w * 1.0 / h
|
ratio = w * 1.0 / h
|
||||||
max_wh_ratio = max(max_wh_ratio, ratio)
|
max_wh_ratio = max(max_wh_ratio, ratio)
|
||||||
|
|
|
@ -110,7 +110,11 @@ class RecModel(object):
|
||||||
return loader, outputs
|
return loader, outputs
|
||||||
elif mode == "export":
|
elif mode == "export":
|
||||||
predict = predicts['predict']
|
predict = predicts['predict']
|
||||||
predict = fluid.layers.softmax(predict)
|
if self.loss_type == "ctc":
|
||||||
|
predict = fluid.layers.softmax(predict)
|
||||||
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
|
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
|
||||||
else:
|
else:
|
||||||
return loader, {'decoded_out': decoded_out}
|
predict = predicts['predict']
|
||||||
|
if self.loss_type == "ctc":
|
||||||
|
predict = fluid.layers.softmax(predict)
|
||||||
|
return loader, {'decoded_out': decoded_out, 'predicts': predict}
|
||||||
|
|
|
@ -123,6 +123,8 @@ class AttentionPredict(object):
|
||||||
|
|
||||||
full_ids = fluid.layers.fill_constant_batch_size_like(
|
full_ids = fluid.layers.fill_constant_batch_size_like(
|
||||||
input=init_state, shape=[-1, 1], dtype='int64', value=1)
|
input=init_state, shape=[-1, 1], dtype='int64', value=1)
|
||||||
|
full_scores = fluid.layers.fill_constant_batch_size_like(
|
||||||
|
input=init_state, shape=[-1, 1], dtype='float32', value=1)
|
||||||
|
|
||||||
cond = layers.less_than(x=counter, y=array_len)
|
cond = layers.less_than(x=counter, y=array_len)
|
||||||
while_op = layers.While(cond=cond)
|
while_op = layers.While(cond=cond)
|
||||||
|
@ -171,6 +173,9 @@ class AttentionPredict(object):
|
||||||
new_ids = fluid.layers.concat([full_ids, topk_indices], axis=1)
|
new_ids = fluid.layers.concat([full_ids, topk_indices], axis=1)
|
||||||
fluid.layers.assign(new_ids, full_ids)
|
fluid.layers.assign(new_ids, full_ids)
|
||||||
|
|
||||||
|
new_scores = fluid.layers.concat([full_scores, topk_scores], axis=1)
|
||||||
|
fluid.layers.assign(new_scores, full_scores)
|
||||||
|
|
||||||
layers.increment(x=counter, value=1, in_place=True)
|
layers.increment(x=counter, value=1, in_place=True)
|
||||||
|
|
||||||
# update the memories
|
# update the memories
|
||||||
|
@ -184,7 +189,7 @@ class AttentionPredict(object):
|
||||||
length_cond = layers.less_than(x=counter, y=array_len)
|
length_cond = layers.less_than(x=counter, y=array_len)
|
||||||
finish_cond = layers.logical_not(layers.is_empty(x=topk_indices))
|
finish_cond = layers.logical_not(layers.is_empty(x=topk_indices))
|
||||||
layers.logical_and(x=length_cond, y=finish_cond, out=cond)
|
layers.logical_and(x=length_cond, y=finish_cond, out=cond)
|
||||||
return full_ids
|
return full_ids, full_scores
|
||||||
|
|
||||||
def __call__(self, inputs, labels=None, mode=None):
|
def __call__(self, inputs, labels=None, mode=None):
|
||||||
encoder_features = self.encoder(inputs)
|
encoder_features = self.encoder(inputs)
|
||||||
|
@ -223,10 +228,10 @@ class AttentionPredict(object):
|
||||||
decoder_size, char_num)
|
decoder_size, char_num)
|
||||||
_, decoded_out = layers.topk(input=predict, k=1)
|
_, decoded_out = layers.topk(input=predict, k=1)
|
||||||
decoded_out = layers.lod_reset(decoded_out, y=label_out)
|
decoded_out = layers.lod_reset(decoded_out, y=label_out)
|
||||||
predicts = {'predict': predict, 'decoded_out': decoded_out}
|
predicts = {'predict':predict, 'decoded_out':decoded_out}
|
||||||
else:
|
else:
|
||||||
ids = self.gru_attention_infer(
|
ids, predict = self.gru_attention_infer(
|
||||||
decoder_boot, self.max_length, char_num, word_vector_dim,
|
decoder_boot, self.max_length, char_num, word_vector_dim,
|
||||||
encoded_vector, encoded_proj, decoder_size)
|
encoded_vector, encoded_proj, decoder_size)
|
||||||
predicts = {'decoded_out': ids}
|
predicts = {'predict':predict, 'decoded_out':ids}
|
||||||
return predicts
|
return predicts
|
||||||
|
|
|
@ -80,26 +80,43 @@ class TextRecognizer(object):
|
||||||
starttime = time.time()
|
starttime = time.time()
|
||||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||||
self.predictor.zero_copy_run()
|
self.predictor.zero_copy_run()
|
||||||
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
|
|
||||||
rec_idx_lod = self.output_tensors[0].lod()[0]
|
if args.rec_algorithm != "RARE":
|
||||||
predict_batch = self.output_tensors[1].copy_to_cpu()
|
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
|
||||||
predict_lod = self.output_tensors[1].lod()[0]
|
rec_idx_lod = self.output_tensors[0].lod()[0]
|
||||||
elapse = time.time() - starttime
|
predict_batch = self.output_tensors[1].copy_to_cpu()
|
||||||
predict_time += elapse
|
predict_lod = self.output_tensors[1].lod()[0]
|
||||||
starttime = time.time()
|
elapse = time.time() - starttime
|
||||||
for rno in range(len(rec_idx_lod) - 1):
|
predict_time += elapse
|
||||||
beg = rec_idx_lod[rno]
|
for rno in range(len(rec_idx_lod) - 1):
|
||||||
end = rec_idx_lod[rno + 1]
|
beg = rec_idx_lod[rno]
|
||||||
rec_idx_tmp = rec_idx_batch[beg:end, 0]
|
end = rec_idx_lod[rno + 1]
|
||||||
preds_text = self.char_ops.decode(rec_idx_tmp)
|
rec_idx_tmp = rec_idx_batch[beg:end, 0]
|
||||||
beg = predict_lod[rno]
|
preds_text = self.char_ops.decode(rec_idx_tmp)
|
||||||
end = predict_lod[rno + 1]
|
beg = predict_lod[rno]
|
||||||
probs = predict_batch[beg:end, :]
|
end = predict_lod[rno + 1]
|
||||||
ind = np.argmax(probs, axis=1)
|
probs = predict_batch[beg:end, :]
|
||||||
blank = probs.shape[1]
|
ind = np.argmax(probs, axis=1)
|
||||||
valid_ind = np.where(ind != (blank - 1))[0]
|
blank = probs.shape[1]
|
||||||
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
valid_ind = np.where(ind != (blank - 1))[0]
|
||||||
rec_res.append([preds_text, score])
|
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
||||||
|
rec_res.append([preds_text, score])
|
||||||
|
else:
|
||||||
|
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
|
||||||
|
predict_batch = self.output_tensors[1].copy_to_cpu()
|
||||||
|
for rno in range(len(rec_idx_batch)):
|
||||||
|
end_pos = np.where(rec_idx_batch[rno, :] == 1)[0]
|
||||||
|
if len(end_pos) <= 1:
|
||||||
|
preds = rec_idx_batch[rno, 1:]
|
||||||
|
score = np.mean(predict_batch[rno, 1:])
|
||||||
|
else:
|
||||||
|
preds = rec_idx_batch[rno, 1:end_pos[1]]
|
||||||
|
score = np.mean(predict_batch[rno, 1:end_pos[1]])
|
||||||
|
#todo: why index has 2 offset
|
||||||
|
preds = preds - 2
|
||||||
|
preds_text = self.char_ops.decode(preds)
|
||||||
|
rec_res.append([preds_text, score])
|
||||||
|
|
||||||
return rec_res, predict_time
|
return rec_res, predict_time
|
||||||
|
|
||||||
|
|
||||||
|
@ -116,7 +133,13 @@ if __name__ == "__main__":
|
||||||
continue
|
continue
|
||||||
valid_image_file_list.append(image_file)
|
valid_image_file_list.append(image_file)
|
||||||
img_list.append(img)
|
img_list.append(img)
|
||||||
rec_res, predict_time = text_recognizer(img_list)
|
try:
|
||||||
|
rec_res, predict_time = text_recognizer(img_list)
|
||||||
|
except:
|
||||||
|
logger.info(
|
||||||
|
"ERROR!! \nInput image shape is not equal with config. TPS does not support variable shape.\n"
|
||||||
|
"Please set --rec_image_shape=input_shape and --rec_char_type='ch' ")
|
||||||
|
exit()
|
||||||
for ino in range(len(img_list)):
|
for ino in range(len(img_list)):
|
||||||
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
|
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
|
||||||
print("Total predict time for %d images:%.3f" %
|
print("Total predict time for %d images:%.3f" %
|
||||||
|
|
|
@ -55,6 +55,7 @@ def main():
|
||||||
program.merge_config(FLAGS.opt)
|
program.merge_config(FLAGS.opt)
|
||||||
logger.info(config)
|
logger.info(config)
|
||||||
char_ops = CharacterOps(config['Global'])
|
char_ops = CharacterOps(config['Global'])
|
||||||
|
loss_type = config['Global']['loss_type']
|
||||||
config['Global']['char_ops'] = char_ops
|
config['Global']['char_ops'] = char_ops
|
||||||
|
|
||||||
# check if set use_gpu=True in paddlepaddle cpu version
|
# check if set use_gpu=True in paddlepaddle cpu version
|
||||||
|
@ -85,29 +86,38 @@ def main():
|
||||||
if len(infer_list) == 0:
|
if len(infer_list) == 0:
|
||||||
logger.info("Can not find img in infer_img dir.")
|
logger.info("Can not find img in infer_img dir.")
|
||||||
for i in range(max_img_num):
|
for i in range(max_img_num):
|
||||||
print("infer_img:", infer_list[i])
|
print("infer_img:%s" % infer_list[i])
|
||||||
img = next(blobs)
|
img = next(blobs)
|
||||||
predict = exe.run(program=eval_prog,
|
predict = exe.run(program=eval_prog,
|
||||||
feed={"image": img},
|
feed={"image": img},
|
||||||
fetch_list=fetch_varname_list,
|
fetch_list=fetch_varname_list,
|
||||||
return_numpy=False)
|
return_numpy=False)
|
||||||
|
if loss_type == "ctc":
|
||||||
preds = np.array(predict[0])
|
preds = np.array(predict[0])
|
||||||
if preds.shape[1] == 1:
|
|
||||||
preds = preds.reshape(-1)
|
preds = preds.reshape(-1)
|
||||||
preds_lod = predict[0].lod()[0]
|
preds_lod = predict[0].lod()[0]
|
||||||
preds_text = char_ops.decode(preds)
|
preds_text = char_ops.decode(preds)
|
||||||
else:
|
probs = np.array(predict[1])
|
||||||
|
ind = np.argmax(probs, axis=1)
|
||||||
|
blank = probs.shape[1]
|
||||||
|
valid_ind = np.where(ind != (blank - 1))[0]
|
||||||
|
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
||||||
|
elif loss_type == "attention":
|
||||||
|
preds = np.array(predict[0])
|
||||||
|
probs = np.array(predict[1])
|
||||||
end_pos = np.where(preds[0, :] == 1)[0]
|
end_pos = np.where(preds[0, :] == 1)[0]
|
||||||
if len(end_pos) <= 1:
|
if len(end_pos) <= 1:
|
||||||
preds_text = preds[0, 1:]
|
preds = preds[0, 1:]
|
||||||
|
score = np.mean(probs[0, 1:])
|
||||||
else:
|
else:
|
||||||
preds_text = preds[0, 1:end_pos[1]]
|
preds = preds[0, 1:end_pos[1]]
|
||||||
preds_text = preds_text.reshape(-1)
|
score = np.mean(probs[0, 1:end_pos[1]])
|
||||||
preds_text = char_ops.decode(preds_text)
|
preds = preds.reshape(-1)
|
||||||
|
preds_text = char_ops.decode(preds)
|
||||||
|
|
||||||
print("\t index:", preds)
|
print("\t index:", preds)
|
||||||
print("\t word :", preds_text)
|
print("\t word :", preds_text)
|
||||||
|
print("\t score :", score)
|
||||||
|
|
||||||
# save for inference model
|
# save for inference model
|
||||||
target_var = []
|
target_var = []
|
||||||
|
|
Loading…
Reference in New Issue