polish code
This commit is contained in:
parent
d0d5de7f4d
commit
a3b291928b
|
@ -32,6 +32,9 @@
|
||||||
| loss_type | 设置 loss 类型 | ctc | 支持两种loss: ctc / attention |
|
| loss_type | 设置 loss 类型 | ctc | 支持两种loss: ctc / attention |
|
||||||
| distort | 设置是否使用数据增强 | false | 设置为true时,将在训练时随机进行扰动,支持的扰动操作可阅读[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py) |
|
| distort | 设置是否使用数据增强 | false | 设置为true时,将在训练时随机进行扰动,支持的扰动操作可阅读[img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py) |
|
||||||
| use_space_char | 设置是否识别空格 | false | 仅在 character_type=ch 时支持空格 |
|
| use_space_char | 设置是否识别空格 | false | 仅在 character_type=ch 时支持空格 |
|
||||||
|
| average_window | ModelAverage优化器中的窗口长度计算比例 | 0.15 | 目前仅应用与SRN |
|
||||||
|
| max_average_window | 平均值计算窗口长度的最大值 | 15625 | 推荐设置为一轮训练中mini-batchs的数目|
|
||||||
|
| min_average_window | 平均值计算窗口长度的最小值 | 10000 | \ |
|
||||||
| reader_yml | 设置reader配置文件 | ./configs/rec/rec_icdar15_reader.yml | \ |
|
| reader_yml | 设置reader配置文件 | ./configs/rec/rec_icdar15_reader.yml | \ |
|
||||||
| pretrain_weights | 加载预训练模型路径 | ./pretrain_models/CRNN/best_accuracy | \ |
|
| pretrain_weights | 加载预训练模型路径 | ./pretrain_models/CRNN/best_accuracy | \ |
|
||||||
| checkpoints | 加载模型参数路径 | None | 用于中断后加载参数继续训练 |
|
| checkpoints | 加载模型参数路径 | None | 用于中断后加载参数继续训练 |
|
||||||
|
|
|
@ -213,6 +213,9 @@ class RecModel(object):
|
||||||
predict = predicts['predict']
|
predict = predicts['predict']
|
||||||
if self.loss_type == "ctc":
|
if self.loss_type == "ctc":
|
||||||
predict = fluid.layers.softmax(predict)
|
predict = fluid.layers.softmax(predict)
|
||||||
|
if self.loss_type == "srn":
|
||||||
|
logger.infor(
|
||||||
|
"Warning! SRN does not support export model currently")
|
||||||
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
|
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
|
||||||
else:
|
else:
|
||||||
predict = predicts['predict']
|
predict = predicts['predict']
|
||||||
|
|
|
@ -69,7 +69,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
|
||||||
return_numpy=False)
|
return_numpy=False)
|
||||||
preds = np.array(outs[0])
|
preds = np.array(outs[0])
|
||||||
|
|
||||||
if preds.shape[1] != 1:
|
if config['Global']['loss_type'] == "attention":
|
||||||
preds, preds_lod = convert_rec_attention_infer_res(preds)
|
preds, preds_lod = convert_rec_attention_infer_res(preds)
|
||||||
else:
|
else:
|
||||||
preds_lod = outs[0].lod()[0]
|
preds_lod = outs[0].lod()[0]
|
||||||
|
@ -123,8 +123,8 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
|
||||||
|
|
||||||
def test_rec_benchmark(exe, config, eval_info_dict):
|
def test_rec_benchmark(exe, config, eval_info_dict):
|
||||||
" Evaluate lmdb dataset "
|
" Evaluate lmdb dataset "
|
||||||
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860','IC03_867', \
|
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', \
|
||||||
'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077','SVTP', 'CUTE80']
|
'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
|
||||||
eval_data_dir = config['TestReader']['lmdb_sets_dir']
|
eval_data_dir = config['TestReader']['lmdb_sets_dir']
|
||||||
total_evaluation_data_number = 0
|
total_evaluation_data_number = 0
|
||||||
total_correct_number = 0
|
total_correct_number = 0
|
||||||
|
|
Loading…
Reference in New Issue