From d3b609ee094f74e12a7acd1aec77d9a1ef29a658 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Tue, 15 Dec 2020 23:49:50 +0800 Subject: [PATCH 1/4] change DBHead output to dict --- ppocr/losses/det_db_loss.py | 1 + ppocr/modeling/heads/det_db_head.py | 4 ++-- ppocr/postprocess/db_postprocess.py | 6 ++++-- tools/infer/predict_det.py | 6 +++--- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/ppocr/losses/det_db_loss.py b/ppocr/losses/det_db_loss.py index f170f673..3e2aa063 100755 --- a/ppocr/losses/det_db_loss.py +++ b/ppocr/losses/det_db_loss.py @@ -47,6 +47,7 @@ class DBLoss(nn.Layer): negative_ratio=ohem_ratio) def forward(self, predicts, labels): + predicts = predicts['maps'] label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[ 1:] shrink_maps = predicts[:, 0, :, :] diff --git a/ppocr/modeling/heads/det_db_head.py b/ppocr/modeling/heads/det_db_head.py index 49c50ffd..ca18d74a 100644 --- a/ppocr/modeling/heads/det_db_head.py +++ b/ppocr/modeling/heads/det_db_head.py @@ -120,9 +120,9 @@ class DBHead(nn.Layer): def forward(self, x): shrink_maps = self.binarize(x) if not self.training: - return shrink_maps + return {'maps': shrink_maps} threshold_maps = self.thresh(x) binary_maps = self.step_function(shrink_maps, threshold_maps) y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1) - return y + return {'maps': y} diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index 16c789dc..91729e0a 100755 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -40,7 +40,8 @@ class DBPostProcess(object): self.max_candidates = max_candidates self.unclip_ratio = unclip_ratio self.min_size = 3 - self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]]) + self.dilation_kernel = None if not use_dilation else np.array( + [[1, 1], [1, 1]]) def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): ''' @@ -132,7 +133,8 @@ class DBPostProcess(object): cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] - def __call__(self, pred, shape_list): + def __call__(self, outs_dict, shape_list): + pred = outs_dict['maps'] if isinstance(pred, paddle.Tensor): pred = pred.numpy() pred = pred[:, 0, :, :] diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index d389ca39..b2260662 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -65,12 +65,12 @@ class TextDetector(object): postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio postprocess_params["use_dilation"] = True elif self.det_algorithm == "EAST": - postprocess_params['name'] = 'EASTPostProcess' + postprocess_params['name'] = 'EASTPostProcess' postprocess_params["score_thresh"] = args.det_east_score_thresh postprocess_params["cover_thresh"] = args.det_east_cover_thresh postprocess_params["nms_thresh"] = args.det_east_nms_thresh elif self.det_algorithm == "SAST": - postprocess_params['name'] = 'SASTPostProcess' + postprocess_params['name'] = 'SASTPostProcess' postprocess_params["score_thresh"] = args.det_sast_score_thresh postprocess_params["nms_thresh"] = args.det_sast_nms_thresh self.det_sast_polygon = args.det_sast_polygon @@ -178,7 +178,7 @@ class TextDetector(object): preds['f_tco'] = outputs[2] preds['f_tvo'] = outputs[3] else: - preds = outputs[0] + preds['maps'] = outputs[0] post_result = self.postprocess_op(preds, shape_list) dt_boxes = post_result[0]['points'] From 2ad0ca44cd3bcc5896da1452185bf0600a15fd4e Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Wed, 16 Dec 2020 00:00:03 +0800 Subject: [PATCH 2/4] update db config --- configs/det/det_mv3_db.yml | 8 ++++---- configs/det/det_r50_vd_db.yml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/configs/det/det_mv3_db.yml b/configs/det/det_mv3_db.yml index 36a6f755..a4ed569f 100644 --- a/configs/det/det_mv3_db.yml +++ b/configs/det/det_mv3_db.yml @@ -2,7 +2,7 @@ Global: use_gpu: true epoch_num: 1200 log_smooth_window: 20 - print_batch_step: 2 + print_batch_step: 10 save_model_dir: ./output/db_mv3/ save_epoch_step: 1200 # evaluation is run every 5000 iterations after the 4000th iteration @@ -39,7 +39,7 @@ Loss: alpha: 5 beta: 10 ohem_ratio: 3 - + Optimizer: name: Adam beta1: 0.9 @@ -100,7 +100,7 @@ Train: loader: shuffle: True drop_last: False - batch_size_per_card: 4 + batch_size_per_card: 16 num_workers: 8 Eval: @@ -128,4 +128,4 @@ Eval: shuffle: False drop_last: False batch_size_per_card: 1 # must be 1 - num_workers: 2 \ No newline at end of file + num_workers: 8 \ No newline at end of file diff --git a/configs/det/det_r50_vd_db.yml b/configs/det/det_r50_vd_db.yml index b70ab750..386a7970 100644 --- a/configs/det/det_r50_vd_db.yml +++ b/configs/det/det_r50_vd_db.yml @@ -6,7 +6,7 @@ Global: save_model_dir: ./output/det_r50_vd/ save_epoch_step: 1200 # evaluation is run every 5000 iterations after the 4000th iteration - eval_batch_step: [5000,4000] + eval_batch_step: [4000,5000] # if pretrained_model is saved in static mode, load_static_weights must set to True load_static_weights: True cal_metric_during_train: False From 41c2af492495c92d6e07cd51078291197d31e718 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Wed, 16 Dec 2020 00:03:50 +0800 Subject: [PATCH 3/4] update predict_det --- tools/infer/predict_det.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index b2260662..ba0adaee 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -177,8 +177,10 @@ class TextDetector(object): preds['f_score'] = outputs[1] preds['f_tco'] = outputs[2] preds['f_tvo'] = outputs[3] - else: + elif self.det_algorithm == 'DB': preds['maps'] = outputs[0] + else: + raise NotImplementedError post_result = self.postprocess_op(preds, shape_list) dt_boxes = post_result[0]['points'] From a8d1f2db94f0ebd2fe2c48527b71379915e3d2f8 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Wed, 16 Dec 2020 13:06:48 +0800 Subject: [PATCH 4/4] update --- configs/det/det_mv3_db.yml | 4 ++-- configs/det/det_r50_vd_db.yml | 4 ++-- ppocr/losses/det_db_loss.py | 8 ++++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/configs/det/det_mv3_db.yml b/configs/det/det_mv3_db.yml index a4ed569f..5c8a0923 100644 --- a/configs/det/det_mv3_db.yml +++ b/configs/det/det_mv3_db.yml @@ -5,8 +5,8 @@ Global: print_batch_step: 10 save_model_dir: ./output/db_mv3/ save_epoch_step: 1200 - # evaluation is run every 5000 iterations after the 4000th iteration - eval_batch_step: [4000, 5000] + # evaluation is run every 2000 iterations + eval_batch_step: [0, 2000] # if pretrained_model is saved in static mode, load_static_weights must set to True load_static_weights: True cal_metric_during_train: False diff --git a/configs/det/det_r50_vd_db.yml b/configs/det/det_r50_vd_db.yml index 386a7970..f1188fe3 100644 --- a/configs/det/det_r50_vd_db.yml +++ b/configs/det/det_r50_vd_db.yml @@ -5,8 +5,8 @@ Global: print_batch_step: 10 save_model_dir: ./output/det_r50_vd/ save_epoch_step: 1200 - # evaluation is run every 5000 iterations after the 4000th iteration - eval_batch_step: [4000,5000] + # evaluation is run every 2000 iterations + eval_batch_step: [0,2000] # if pretrained_model is saved in static mode, load_static_weights must set to True load_static_weights: True cal_metric_during_train: False diff --git a/ppocr/losses/det_db_loss.py b/ppocr/losses/det_db_loss.py index 3e2aa063..b079aabf 100755 --- a/ppocr/losses/det_db_loss.py +++ b/ppocr/losses/det_db_loss.py @@ -47,12 +47,12 @@ class DBLoss(nn.Layer): negative_ratio=ohem_ratio) def forward(self, predicts, labels): - predicts = predicts['maps'] + predict_maps = predicts['maps'] label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[ 1:] - shrink_maps = predicts[:, 0, :, :] - threshold_maps = predicts[:, 1, :, :] - binary_maps = predicts[:, 2, :, :] + shrink_maps = predict_maps[:, 0, :, :] + threshold_maps = predict_maps[:, 1, :, :] + binary_maps = predict_maps[:, 2, :, :] loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map, label_shrink_mask)