refine code
This commit is contained in:
parent
c5f33b0049
commit
b5de79b2c9
|
@ -342,6 +342,7 @@ def generate_pivot_list_curved(p_score,
|
|||
center_pos_yxs = []
|
||||
end_points_yxs = []
|
||||
instance_center_pos_yxs = []
|
||||
pred_strs = []
|
||||
if instance_count > 0:
|
||||
for instance_id in range(1, instance_count):
|
||||
pos_list = []
|
||||
|
@ -367,12 +368,13 @@ def generate_pivot_list_curved(p_score,
|
|||
if is_backbone:
|
||||
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
|
||||
instance_center_pos_yxs.append(keep_yxs_list_with_id)
|
||||
pred_strs.append(decoded_str)
|
||||
else:
|
||||
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
|
||||
center_pos_yxs.extend(keep_yxs_list)
|
||||
|
||||
if is_backbone:
|
||||
return instance_center_pos_yxs
|
||||
return pred_strs, instance_center_pos_yxs
|
||||
else:
|
||||
return center_pos_yxs, end_points_yxs
|
||||
|
||||
|
|
|
@ -85,32 +85,13 @@ class PGNet_PostProcess(object):
|
|||
p_char = p_char[0]
|
||||
src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
|
||||
is_curved = self.valid_set == "totaltext"
|
||||
instance_yxs_list = generate_pivot_list_slow(
|
||||
char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow(
|
||||
p_score,
|
||||
p_char,
|
||||
p_direction,
|
||||
score_thresh=self.score_thresh,
|
||||
is_backbone=True,
|
||||
is_curved=is_curved)
|
||||
p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
|
||||
char_seq_idx_set = []
|
||||
for i in range(len(instance_yxs_list)):
|
||||
gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
|
||||
f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
|
||||
feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
|
||||
feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
|
||||
feature_len = [len(feature_seq[0])]
|
||||
featyre_seq = paddle.to_tensor(feature_seq)
|
||||
feature_len = np.array([feature_len]).astype(np.int64)
|
||||
length = paddle.to_tensor(feature_len)
|
||||
seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
|
||||
input=featyre_seq, blank=36, input_length=length)
|
||||
seq_pred_str = seq_pred[0].numpy().tolist()[0]
|
||||
seq_len = seq_pred[1].numpy()[0][0]
|
||||
temp_t = []
|
||||
for c in seq_pred_str[:seq_len]:
|
||||
temp_t.append(c)
|
||||
char_seq_idx_set.append(temp_t)
|
||||
seq_strs = []
|
||||
for char_idx_set in char_seq_idx_set:
|
||||
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
|
||||
|
|
Loading…
Reference in New Issue