ADD PGnet_v3
This commit is contained in:
parent
bb49e1a53f
commit
310d399b83
|
@ -19,10 +19,11 @@ import random
|
||||||
|
|
||||||
|
|
||||||
class PGDateSet(Dataset):
|
class PGDateSet(Dataset):
|
||||||
def __init__(self, config, mode, logger):
|
def __init__(self, config, mode, logger, seed=None):
|
||||||
super(PGDateSet, self).__init__()
|
super(PGDateSet, self).__init__()
|
||||||
|
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
self.seed = seed
|
||||||
global_config = config['Global']
|
global_config = config['Global']
|
||||||
dataset_config = config[mode]['dataset']
|
dataset_config = config[mode]['dataset']
|
||||||
loader_config = config[mode]['loader']
|
loader_config = config[mode]['loader']
|
||||||
|
@ -36,7 +37,6 @@ class PGDateSet(Dataset):
|
||||||
assert len(
|
assert len(
|
||||||
ratio_list
|
ratio_list
|
||||||
) == data_source_num, "The length of ratio_list should be the same as the file_list."
|
) == data_source_num, "The length of ratio_list should be the same as the file_list."
|
||||||
# self.data_dir = dataset_config['data_dir']
|
|
||||||
self.do_shuffle = loader_config['shuffle']
|
self.do_shuffle = loader_config['shuffle']
|
||||||
|
|
||||||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||||
|
@ -50,6 +50,7 @@ class PGDateSet(Dataset):
|
||||||
|
|
||||||
def shuffle_data_random(self):
|
def shuffle_data_random(self):
|
||||||
if self.do_shuffle:
|
if self.do_shuffle:
|
||||||
|
random.seed(self.seed)
|
||||||
random.shuffle(self.data_lines)
|
random.shuffle(self.data_lines)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -122,6 +123,7 @@ class PGDateSet(Dataset):
|
||||||
else:
|
else:
|
||||||
print("Unrecognized data format...")
|
print("Unrecognized data format...")
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
random.seed(self.seed)
|
||||||
image_files = random.sample(
|
image_files = random.sample(
|
||||||
image_files, round(len(image_files) * ratio_list[idx]))
|
image_files, round(len(image_files) * ratio_list[idx]))
|
||||||
data_lines.extend(image_files)
|
data_lines.extend(image_files)
|
||||||
|
|
|
@ -113,7 +113,6 @@ class PGPostProcess(object):
|
||||||
all_point_pair_list = []
|
all_point_pair_list = []
|
||||||
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
|
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
|
||||||
if len(yx_center_line) == 1:
|
if len(yx_center_line) == 1:
|
||||||
print('the length of tcl point is less than 2, repeat')
|
|
||||||
yx_center_line.append(yx_center_line[-1])
|
yx_center_line.append(yx_center_line[-1])
|
||||||
|
|
||||||
# expand corresponding offset for total-text.
|
# expand corresponding offset for total-text.
|
||||||
|
@ -148,7 +147,6 @@ class PGPostProcess(object):
|
||||||
|
|
||||||
# ndarry: (x, 2)
|
# ndarry: (x, 2)
|
||||||
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
|
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
|
||||||
print('expand along width. {}'.format(detected_poly.shape))
|
|
||||||
detected_poly = expand_poly_along_width(
|
detected_poly = expand_poly_along_width(
|
||||||
detected_poly, shrink_ratio_of_width=0.2)
|
detected_poly, shrink_ratio_of_width=0.2)
|
||||||
detected_poly[:, 0] = np.clip(
|
detected_poly[:, 0] = np.clip(
|
||||||
|
@ -157,7 +155,6 @@ class PGPostProcess(object):
|
||||||
detected_poly[:, 1], a_min=0, a_max=src_h)
|
detected_poly[:, 1], a_min=0, a_max=src_h)
|
||||||
|
|
||||||
if len(keep_str) < 2:
|
if len(keep_str) < 2:
|
||||||
print('--> too short, {}'.format(keep_str))
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
keep_str_list.append(keep_str)
|
keep_str_list.append(keep_str)
|
||||||
|
@ -175,20 +172,4 @@ class PGPostProcess(object):
|
||||||
'points': poly_list,
|
'points': poly_list,
|
||||||
'strs': keep_str_list,
|
'strs': keep_str_list,
|
||||||
}
|
}
|
||||||
# visualization
|
|
||||||
# if self.save_visualization:
|
|
||||||
# visualize_e2e_result(im_fn, poly_list, keep_str_list, src_im)
|
|
||||||
# visualize_point_result(im_fn, all_point_list, all_point_pair_list, src_im)
|
|
||||||
|
|
||||||
# save detected boxes
|
|
||||||
# txt_dir = (result_path[:-1] if result_path.endswith('/') else result_path) + '_txt_anno'
|
|
||||||
# if not os.path.exists(txt_dir):
|
|
||||||
# os.makedirs(txt_dir)
|
|
||||||
# res_file = os.path.join(txt_dir, '{}.txt'.format(im_prefix))
|
|
||||||
# with open(res_file, 'w') as f:
|
|
||||||
# for i_box, box in enumerate(poly_list):
|
|
||||||
# seq_str = keep_str_list[i_box]
|
|
||||||
# box = np.round(box).astype('int32')
|
|
||||||
# box_str = ','.join(str(s) for s in (box.flatten().tolist()))
|
|
||||||
# f.write('{}\t{}\r\n'.format(box_str, seq_str))
|
|
||||||
return data
|
return data
|
||||||
|
|
Loading…
Reference in New Issue