fix predict_det not found unclip_ratio

This commit is contained in:
LDOUBLEV 2020-05-25 18:14:13 +08:00
parent e9b6f20d6b
commit 8a6f106237
4 changed files with 13 additions and 4 deletions

View File

@ -64,13 +64,13 @@ PaddleOCR开源的文本检测算法列表
在ICDAR2015文本检测公开数据集上算法效果如下 在ICDAR2015文本检测公开数据集上算法效果如下
|模型|骨干网络|precision|recall|Hmean|下载链接| |模型|骨干网络|precision|recall|Hmean|下载链接|
|-|-|-|-| |-|-|-|-|-|-|
|EAST|ResNet50_vd|88.18%|85.51|86.82%|[下载链接](https://paddleocr.bj.bcebos.com/det_r50_vd_east.tar)| |EAST|ResNet50_vd|88.18%|85.51|86.82%|[下载链接](https://paddleocr.bj.bcebos.com/det_r50_vd_east.tar)|
|EAST|MobileNetV3|81.67%|79.83%|80.74%|[下载链接](https://paddleocr.bj.bcebos.com/det_mv3_east.tar)| |EAST|MobileNetV3|81.67%|79.83%|80.74%|[下载链接](https://paddleocr.bj.bcebos.com/det_mv3_east.tar)|
|DB|ResNet50_vd|83.79%|80.65%|82.19%|[下载链接](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)| |DB|ResNet50_vd|83.79%|80.65%|82.19%|[下载链接](https://paddleocr.bj.bcebos.com/det_r50_vd_db.tar)|
|DB|MobileNetV3|75.92%|73.18%|74.53%|[下载链接](https://paddleocr.bj.bcebos.com/det_mv3_db.tar)| |DB|MobileNetV3|75.92%|73.18%|74.53%|[下载链接](https://paddleocr.bj.bcebos.com/det_mv3_db.tar)|
* 注: 上述模型的训练和评估设置后处理参数box_thresh=0.6unclip_ratio=1.5,使用不同数据集、不同模型训练,可调整这两个参数进行优化 * 注: 上述DB模型的训练和评估,设置后处理参数box_thresh=0.6unclip_ratio=1.5,使用不同数据集、不同模型训练,可调整这两个参数进行优化
PaddleOCR文本检测算法的训练和使用请参考文档教程中[文本检测模型训练/评估/预测](./doc/detection.md)。 PaddleOCR文本检测算法的训练和使用请参考文档教程中[文本检测模型训练/评估/预测](./doc/detection.md)。

View File

@ -75,6 +75,7 @@ python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{pat
python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5 python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
``` ```
* 注box_thresh、unclip_ratio是DB后处理所需要的参数在评估EAST模型时不需要设置
## 测试检测效果 ## 测试检测效果
@ -83,6 +84,12 @@ python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./ou
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy"
``` ```
测试DB模型时调整后处理阈值
```
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/img_10.jpg" Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
测试文件夹下所有图像的检测效果 测试文件夹下所有图像的检测效果
``` ```
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy" python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o TestReader.infer_img="./doc/imgs_en/" Global.checkpoints="./output/det_db/best_accuracy"

View File

@ -39,6 +39,7 @@ class TextDetector(object):
postprocess_params["thresh"] = args.det_db_thresh postprocess_params["thresh"] = args.det_db_thresh
postprocess_params["box_thresh"] = args.det_db_box_thresh postprocess_params["box_thresh"] = args.det_db_box_thresh
postprocess_params["max_candidates"] = 1000 postprocess_params["max_candidates"] = 1000
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
self.postprocess_op = DBPostProcess(postprocess_params) self.postprocess_op = DBPostProcess(postprocess_params)
elif self.det_algorithm == "EAST": elif self.det_algorithm == "EAST":
self.preprocess_op = EASTProcessTest(preprocess_params) self.preprocess_op = EASTProcessTest(preprocess_params)
@ -143,4 +144,4 @@ if __name__ == "__main__":
img_name_pure = image_file.split("/")[-1] img_name_pure = image_file.split("/")[-1]
cv2.imwrite("./inference_results/det_res_%s" % img_name_pure, src_im) cv2.imwrite("./inference_results/det_res_%s" % img_name_pure, src_im)
if count > 1: if count > 1:
print("Avg Time:", total_time / (count - 1)) print("Avg Time:", total_time / (count - 1))

View File

@ -45,6 +45,7 @@ def parse_args():
#DB parmas #DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3) parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5) parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0)
#EAST parmas #EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8) parser.add_argument("--det_east_score_thresh", type=float, default=0.8)