change DBHead output to dict
This commit is contained in:
parent
57e6edd97c
commit
d3b609ee09
|
@ -47,6 +47,7 @@ class DBLoss(nn.Layer):
|
|||
negative_ratio=ohem_ratio)
|
||||
|
||||
def forward(self, predicts, labels):
|
||||
predicts = predicts['maps']
|
||||
label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[
|
||||
1:]
|
||||
shrink_maps = predicts[:, 0, :, :]
|
||||
|
|
|
@ -120,9 +120,9 @@ class DBHead(nn.Layer):
|
|||
def forward(self, x):
|
||||
shrink_maps = self.binarize(x)
|
||||
if not self.training:
|
||||
return shrink_maps
|
||||
return {'maps': shrink_maps}
|
||||
|
||||
threshold_maps = self.thresh(x)
|
||||
binary_maps = self.step_function(shrink_maps, threshold_maps)
|
||||
y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1)
|
||||
return y
|
||||
return {'maps': y}
|
||||
|
|
|
@ -40,7 +40,8 @@ class DBPostProcess(object):
|
|||
self.max_candidates = max_candidates
|
||||
self.unclip_ratio = unclip_ratio
|
||||
self.min_size = 3
|
||||
self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
|
||||
self.dilation_kernel = None if not use_dilation else np.array(
|
||||
[[1, 1], [1, 1]])
|
||||
|
||||
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
||||
'''
|
||||
|
@ -132,7 +133,8 @@ class DBPostProcess(object):
|
|||
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
|
||||
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
||||
|
||||
def __call__(self, pred, shape_list):
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
pred = outs_dict['maps']
|
||||
if isinstance(pred, paddle.Tensor):
|
||||
pred = pred.numpy()
|
||||
pred = pred[:, 0, :, :]
|
||||
|
|
|
@ -65,12 +65,12 @@ class TextDetector(object):
|
|||
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
|
||||
postprocess_params["use_dilation"] = True
|
||||
elif self.det_algorithm == "EAST":
|
||||
postprocess_params['name'] = 'EASTPostProcess'
|
||||
postprocess_params['name'] = 'EASTPostProcess'
|
||||
postprocess_params["score_thresh"] = args.det_east_score_thresh
|
||||
postprocess_params["cover_thresh"] = args.det_east_cover_thresh
|
||||
postprocess_params["nms_thresh"] = args.det_east_nms_thresh
|
||||
elif self.det_algorithm == "SAST":
|
||||
postprocess_params['name'] = 'SASTPostProcess'
|
||||
postprocess_params['name'] = 'SASTPostProcess'
|
||||
postprocess_params["score_thresh"] = args.det_sast_score_thresh
|
||||
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
|
||||
self.det_sast_polygon = args.det_sast_polygon
|
||||
|
@ -178,7 +178,7 @@ class TextDetector(object):
|
|||
preds['f_tco'] = outputs[2]
|
||||
preds['f_tvo'] = outputs[3]
|
||||
else:
|
||||
preds = outputs[0]
|
||||
preds['maps'] = outputs[0]
|
||||
|
||||
post_result = self.postprocess_op(preds, shape_list)
|
||||
dt_boxes = post_result[0]['points']
|
||||
|
|
Loading…
Reference in New Issue