polish code

This commit is contained in:
tink2123 2020-08-16 16:46:22 +08:00
parent d0d5de7f4d
commit a3b291928b
3 changed files with 9 additions and 3 deletions

View File

@ -32,6 +32,9 @@
| loss_type | 设置 loss 类型 | ctc | 支持两种loss ctc / attention |
| distort | 设置是否使用数据增强 | false | 设置为true时将在训练时随机进行扰动支持的扰动操作可阅读[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py) |
| use_space_char | 设置是否识别空格 | false | 仅在 character_type=ch 时支持空格 |
| average_window | ModelAverage优化器中的窗口长度计算比例 | 0.15 | 目前仅应用与SRN |
| max_average_window | 平均值计算窗口长度的最大值 | 15625 | 推荐设置为一轮训练中mini-batchs的数目|
| min_average_window | 平均值计算窗口长度的最小值 | 10000 | \ |
| reader_yml | 设置reader配置文件 | ./configs/rec/rec_icdar15_reader.yml | \ |
| pretrain_weights | 加载预训练模型路径 | ./pretrain_models/CRNN/best_accuracy | \ |
| checkpoints | 加载模型参数路径 | None | 用于中断后加载参数继续训练 |

View File

@ -213,6 +213,9 @@ class RecModel(object):
predict = predicts['predict']
if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict)
if self.loss_type == "srn":
logger.infor(
"Warning! SRN does not support export model currently")
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
else:
predict = predicts['predict']

View File

@ -69,7 +69,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
return_numpy=False)
preds = np.array(outs[0])
if preds.shape[1] != 1:
if config['Global']['loss_type'] == "attention":
preds, preds_lod = convert_rec_attention_infer_res(preds)
else:
preds_lod = outs[0].lod()[0]
@ -123,8 +123,8 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
def test_rec_benchmark(exe, config, eval_info_dict):
" Evaluate lmdb dataset "
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860','IC03_867', \
'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077','SVTP', 'CUTE80']
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', \
'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
eval_data_dir = config['TestReader']['lmdb_sets_dir']
total_evaluation_data_number = 0
total_correct_number = 0