update infer doc and fix yml
This commit is contained in:
parent
80c188785c
commit
9023a5c57a
|
@ -1,6 +1,6 @@
|
||||||
Global:
|
Global:
|
||||||
algorithm: CRNN
|
algorithm: CRNN
|
||||||
use_gpu: false
|
use_gpu: true
|
||||||
epoch_num: 3000
|
epoch_num: 3000
|
||||||
log_smooth_window: 20
|
log_smooth_window: 20
|
||||||
print_batch_step: 10
|
print_batch_step: 10
|
||||||
|
@ -16,7 +16,7 @@ Global:
|
||||||
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
|
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
|
||||||
loss_type: ctc
|
loss_type: ctc
|
||||||
reader_yml: ./configs/rec/rec_chinese_reader.yml
|
reader_yml: ./configs/rec/rec_chinese_reader.yml
|
||||||
pretrain_weights: output/rec_CRNN/rec_mv3_crnn/best_accuracy
|
pretrain_weights:
|
||||||
checkpoints:
|
checkpoints:
|
||||||
save_inference_dir:
|
save_inference_dir:
|
||||||
infer_img:
|
infer_img:
|
||||||
|
|
|
@ -15,7 +15,7 @@ Global:
|
||||||
character_type: en
|
character_type: en
|
||||||
loss_type: ctc
|
loss_type: ctc
|
||||||
reader_yml: ./configs/rec/rec_icdar15_reader.yml
|
reader_yml: ./configs/rec/rec_icdar15_reader.yml
|
||||||
pretrain_weights:
|
pretrain_weights: ./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy
|
||||||
checkpoints:
|
checkpoints:
|
||||||
save_inference_dir:
|
save_inference_dir:
|
||||||
infer_img:
|
infer_img:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
Global:
|
Global:
|
||||||
algorithm: CRNN
|
algorithm: CRNN
|
||||||
use_gpu: false
|
use_gpu: true
|
||||||
epoch_num: 72
|
epoch_num: 72
|
||||||
log_smooth_window: 20
|
log_smooth_window: 20
|
||||||
print_batch_step: 10
|
print_batch_step: 10
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
Global:
|
Global:
|
||||||
algorithm: RARE
|
algorithm: RARE
|
||||||
use_gpu: false
|
use_gpu: true
|
||||||
epoch_num: 72
|
epoch_num: 72
|
||||||
log_smooth_window: 20
|
log_smooth_window: 20
|
||||||
print_batch_step: 10
|
print_batch_step: 10
|
||||||
|
@ -12,8 +12,7 @@ Global:
|
||||||
test_batch_size_per_card: 256
|
test_batch_size_per_card: 256
|
||||||
image_shape: [3, 32, 100]
|
image_shape: [3, 32, 100]
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
character_type: ch
|
character_type: en
|
||||||
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
|
|
||||||
loss_type: attention
|
loss_type: attention
|
||||||
tps: true
|
tps: true
|
||||||
reader_yml: ./configs/rec/rec_benchmark_reader.yml
|
reader_yml: ./configs/rec/rec_benchmark_reader.yml
|
||||||
|
|
|
@ -165,6 +165,12 @@ STAR-Net文本识别模型推理,可以执行如下命令:
|
||||||
```
|
```
|
||||||
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
|
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
RARE 文本识别模型推理,可以执行如下命令:
|
||||||
|
```
|
||||||
|
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/sare/" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_algorithm="RARE"
|
||||||
|
```
|
||||||
|
|
||||||
![](imgs_words_en/word_336.png)
|
![](imgs_words_en/word_336.png)
|
||||||
|
|
||||||
执行命令后,上面图像的识别结果如下:
|
执行命令后,上面图像的识别结果如下:
|
||||||
|
|
|
@ -32,10 +32,14 @@ class TextRecognizer(object):
|
||||||
self.rec_image_shape = image_shape
|
self.rec_image_shape = image_shape
|
||||||
self.character_type = args.rec_char_type
|
self.character_type = args.rec_char_type
|
||||||
self.rec_batch_num = args.rec_batch_num
|
self.rec_batch_num = args.rec_batch_num
|
||||||
|
self.rec_algorithm = args.rec_algorithm
|
||||||
char_ops_params = {}
|
char_ops_params = {}
|
||||||
char_ops_params["character_type"] = args.rec_char_type
|
char_ops_params["character_type"] = args.rec_char_type
|
||||||
char_ops_params["character_dict_path"] = args.rec_char_dict_path
|
char_ops_params["character_dict_path"] = args.rec_char_dict_path
|
||||||
char_ops_params['loss_type'] = 'ctc'
|
if self.rec_algorithm != "RARE":
|
||||||
|
char_ops_params['loss_type'] = 'ctc'
|
||||||
|
else:
|
||||||
|
char_ops_params['loss_type'] = 'attention'
|
||||||
self.char_ops = CharacterOps(char_ops_params)
|
self.char_ops = CharacterOps(char_ops_params)
|
||||||
|
|
||||||
def resize_norm_img(self, img, max_wh_ratio):
|
def resize_norm_img(self, img, max_wh_ratio):
|
||||||
|
@ -81,7 +85,7 @@ class TextRecognizer(object):
|
||||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||||
self.predictor.zero_copy_run()
|
self.predictor.zero_copy_run()
|
||||||
|
|
||||||
if args.rec_algorithm != "RARE":
|
if self.rec_algorithm != "RARE":
|
||||||
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
|
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
|
||||||
rec_idx_lod = self.output_tensors[0].lod()[0]
|
rec_idx_lod = self.output_tensors[0].lod()[0]
|
||||||
predict_batch = self.output_tensors[1].copy_to_cpu()
|
predict_batch = self.output_tensors[1].copy_to_cpu()
|
||||||
|
@ -104,6 +108,8 @@ class TextRecognizer(object):
|
||||||
else:
|
else:
|
||||||
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
|
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
|
||||||
predict_batch = self.output_tensors[1].copy_to_cpu()
|
predict_batch = self.output_tensors[1].copy_to_cpu()
|
||||||
|
elapse = time.time() - starttime
|
||||||
|
predict_time += elapse
|
||||||
for rno in range(len(rec_idx_batch)):
|
for rno in range(len(rec_idx_batch)):
|
||||||
end_pos = np.where(rec_idx_batch[rno, :] == 1)[0]
|
end_pos = np.where(rec_idx_batch[rno, :] == 1)[0]
|
||||||
if len(end_pos) <= 1:
|
if len(end_pos) <= 1:
|
||||||
|
@ -112,8 +118,6 @@ class TextRecognizer(object):
|
||||||
else:
|
else:
|
||||||
preds = rec_idx_batch[rno, 1:end_pos[1]]
|
preds = rec_idx_batch[rno, 1:end_pos[1]]
|
||||||
score = np.mean(predict_batch[rno, 1:end_pos[1]])
|
score = np.mean(predict_batch[rno, 1:end_pos[1]])
|
||||||
#attenton index has 2 offset: beg and end
|
|
||||||
preds = preds - 2
|
|
||||||
preds_text = self.char_ops.decode(preds)
|
preds_text = self.char_ops.decode(preds)
|
||||||
rec_res.append([preds_text, score])
|
rec_res.append([preds_text, score])
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue