Merge pull request #1449 from WenmuZhou/tree_doc
[Dygraph] change DBHead output to dict and update db config
This commit is contained in:
commit
f38a22c0b3
|
@ -2,11 +2,11 @@ Global:
|
||||||
use_gpu: true
|
use_gpu: true
|
||||||
epoch_num: 1200
|
epoch_num: 1200
|
||||||
log_smooth_window: 20
|
log_smooth_window: 20
|
||||||
print_batch_step: 2
|
print_batch_step: 10
|
||||||
save_model_dir: ./output/db_mv3/
|
save_model_dir: ./output/db_mv3/
|
||||||
save_epoch_step: 1200
|
save_epoch_step: 1200
|
||||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
# evaluation is run every 2000 iterations
|
||||||
eval_batch_step: [4000, 5000]
|
eval_batch_step: [0, 2000]
|
||||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||||
load_static_weights: True
|
load_static_weights: True
|
||||||
cal_metric_during_train: False
|
cal_metric_during_train: False
|
||||||
|
@ -39,7 +39,7 @@ Loss:
|
||||||
alpha: 5
|
alpha: 5
|
||||||
beta: 10
|
beta: 10
|
||||||
ohem_ratio: 3
|
ohem_ratio: 3
|
||||||
|
|
||||||
Optimizer:
|
Optimizer:
|
||||||
name: Adam
|
name: Adam
|
||||||
beta1: 0.9
|
beta1: 0.9
|
||||||
|
@ -100,7 +100,7 @@ Train:
|
||||||
loader:
|
loader:
|
||||||
shuffle: True
|
shuffle: True
|
||||||
drop_last: False
|
drop_last: False
|
||||||
batch_size_per_card: 4
|
batch_size_per_card: 16
|
||||||
num_workers: 8
|
num_workers: 8
|
||||||
|
|
||||||
Eval:
|
Eval:
|
||||||
|
@ -128,4 +128,4 @@ Eval:
|
||||||
shuffle: False
|
shuffle: False
|
||||||
drop_last: False
|
drop_last: False
|
||||||
batch_size_per_card: 1 # must be 1
|
batch_size_per_card: 1 # must be 1
|
||||||
num_workers: 2
|
num_workers: 8
|
|
@ -5,8 +5,8 @@ Global:
|
||||||
print_batch_step: 10
|
print_batch_step: 10
|
||||||
save_model_dir: ./output/det_r50_vd/
|
save_model_dir: ./output/det_r50_vd/
|
||||||
save_epoch_step: 1200
|
save_epoch_step: 1200
|
||||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
# evaluation is run every 2000 iterations
|
||||||
eval_batch_step: [5000,4000]
|
eval_batch_step: [0,2000]
|
||||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||||
load_static_weights: True
|
load_static_weights: True
|
||||||
cal_metric_during_train: False
|
cal_metric_during_train: False
|
||||||
|
|
|
@ -47,11 +47,12 @@ class DBLoss(nn.Layer):
|
||||||
negative_ratio=ohem_ratio)
|
negative_ratio=ohem_ratio)
|
||||||
|
|
||||||
def forward(self, predicts, labels):
|
def forward(self, predicts, labels):
|
||||||
|
predict_maps = predicts['maps']
|
||||||
label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[
|
label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[
|
||||||
1:]
|
1:]
|
||||||
shrink_maps = predicts[:, 0, :, :]
|
shrink_maps = predict_maps[:, 0, :, :]
|
||||||
threshold_maps = predicts[:, 1, :, :]
|
threshold_maps = predict_maps[:, 1, :, :]
|
||||||
binary_maps = predicts[:, 2, :, :]
|
binary_maps = predict_maps[:, 2, :, :]
|
||||||
|
|
||||||
loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map,
|
loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map,
|
||||||
label_shrink_mask)
|
label_shrink_mask)
|
||||||
|
|
|
@ -120,9 +120,9 @@ class DBHead(nn.Layer):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
shrink_maps = self.binarize(x)
|
shrink_maps = self.binarize(x)
|
||||||
if not self.training:
|
if not self.training:
|
||||||
return shrink_maps
|
return {'maps': shrink_maps}
|
||||||
|
|
||||||
threshold_maps = self.thresh(x)
|
threshold_maps = self.thresh(x)
|
||||||
binary_maps = self.step_function(shrink_maps, threshold_maps)
|
binary_maps = self.step_function(shrink_maps, threshold_maps)
|
||||||
y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1)
|
y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1)
|
||||||
return y
|
return {'maps': y}
|
||||||
|
|
|
@ -40,7 +40,8 @@ class DBPostProcess(object):
|
||||||
self.max_candidates = max_candidates
|
self.max_candidates = max_candidates
|
||||||
self.unclip_ratio = unclip_ratio
|
self.unclip_ratio = unclip_ratio
|
||||||
self.min_size = 3
|
self.min_size = 3
|
||||||
self.dilation_kernel = None if not use_dilation else np.array([[1, 1], [1, 1]])
|
self.dilation_kernel = None if not use_dilation else np.array(
|
||||||
|
[[1, 1], [1, 1]])
|
||||||
|
|
||||||
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
||||||
'''
|
'''
|
||||||
|
@ -132,7 +133,8 @@ class DBPostProcess(object):
|
||||||
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
|
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
|
||||||
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
|
||||||
|
|
||||||
def __call__(self, pred, shape_list):
|
def __call__(self, outs_dict, shape_list):
|
||||||
|
pred = outs_dict['maps']
|
||||||
if isinstance(pred, paddle.Tensor):
|
if isinstance(pred, paddle.Tensor):
|
||||||
pred = pred.numpy()
|
pred = pred.numpy()
|
||||||
pred = pred[:, 0, :, :]
|
pred = pred[:, 0, :, :]
|
||||||
|
|
|
@ -65,12 +65,12 @@ class TextDetector(object):
|
||||||
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
|
postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
|
||||||
postprocess_params["use_dilation"] = True
|
postprocess_params["use_dilation"] = True
|
||||||
elif self.det_algorithm == "EAST":
|
elif self.det_algorithm == "EAST":
|
||||||
postprocess_params['name'] = 'EASTPostProcess'
|
postprocess_params['name'] = 'EASTPostProcess'
|
||||||
postprocess_params["score_thresh"] = args.det_east_score_thresh
|
postprocess_params["score_thresh"] = args.det_east_score_thresh
|
||||||
postprocess_params["cover_thresh"] = args.det_east_cover_thresh
|
postprocess_params["cover_thresh"] = args.det_east_cover_thresh
|
||||||
postprocess_params["nms_thresh"] = args.det_east_nms_thresh
|
postprocess_params["nms_thresh"] = args.det_east_nms_thresh
|
||||||
elif self.det_algorithm == "SAST":
|
elif self.det_algorithm == "SAST":
|
||||||
postprocess_params['name'] = 'SASTPostProcess'
|
postprocess_params['name'] = 'SASTPostProcess'
|
||||||
postprocess_params["score_thresh"] = args.det_sast_score_thresh
|
postprocess_params["score_thresh"] = args.det_sast_score_thresh
|
||||||
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
|
postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
|
||||||
self.det_sast_polygon = args.det_sast_polygon
|
self.det_sast_polygon = args.det_sast_polygon
|
||||||
|
@ -177,8 +177,10 @@ class TextDetector(object):
|
||||||
preds['f_score'] = outputs[1]
|
preds['f_score'] = outputs[1]
|
||||||
preds['f_tco'] = outputs[2]
|
preds['f_tco'] = outputs[2]
|
||||||
preds['f_tvo'] = outputs[3]
|
preds['f_tvo'] = outputs[3]
|
||||||
|
elif self.det_algorithm == 'DB':
|
||||||
|
preds['maps'] = outputs[0]
|
||||||
else:
|
else:
|
||||||
preds = outputs[0]
|
raise NotImplementedError
|
||||||
|
|
||||||
post_result = self.postprocess_op(preds, shape_list)
|
post_result = self.postprocess_op(preds, shape_list)
|
||||||
dt_boxes = post_result[0]['points']
|
dt_boxes = post_result[0]['points']
|
||||||
|
|
Loading…
Reference in New Issue