fix config format
This commit is contained in:
parent
50bcec4661
commit
be5fdae573
|
@ -69,10 +69,9 @@ Metric:
|
||||||
Train:
|
Train:
|
||||||
dataset:
|
dataset:
|
||||||
name: PGDataSet
|
name: PGDataSet
|
||||||
data_dir: ./train_data/
|
data_dir: ./train_data/train
|
||||||
label_file_list: [.././train_data/total_text/train/]
|
label_file_list: [.././train_data/total_text/train/]
|
||||||
ratio_list: [1.0]
|
ratio_list: [1.0]
|
||||||
data_format: icdar #two data format: icdar/textnet
|
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
img_mode: BGR
|
img_mode: BGR
|
||||||
|
@ -94,7 +93,7 @@ Train:
|
||||||
Eval:
|
Eval:
|
||||||
dataset:
|
dataset:
|
||||||
name: PGDataSet
|
name: PGDataSet
|
||||||
data_dir: ./train_data/
|
data_dir: ./train_data/test
|
||||||
label_file_list: [./train_data/total_text/test/]
|
label_file_list: [./train_data/total_text/test/]
|
||||||
transforms:
|
transforms:
|
||||||
- DecodeImage: # load image
|
- DecodeImage: # load image
|
||||||
|
|
|
@ -78,7 +78,10 @@ class PGDataSet(Dataset):
|
||||||
file_name = substr[0]
|
file_name = substr[0]
|
||||||
label = substr[1]
|
label = substr[1]
|
||||||
img_path = os.path.join(self.data_dir, file_name)
|
img_path = os.path.join(self.data_dir, file_name)
|
||||||
img_id = int(data_line.split(".")[0][3:])
|
if self.mode.lower() == 'eval':
|
||||||
|
img_id = int(data_line.split(".")[0][7:])
|
||||||
|
else:
|
||||||
|
img_id = 0
|
||||||
data = {'img_path': img_path, 'label': label, 'img_id': img_id}
|
data = {'img_path': img_path, 'label': label, 'img_id': img_id}
|
||||||
if not os.path.exists(img_path):
|
if not os.path.exists(img_path):
|
||||||
raise Exception("{} does not exist!".format(img_path))
|
raise Exception("{} does not exist!".format(img_path))
|
||||||
|
|
|
@ -122,7 +122,7 @@ class TextE2E(object):
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
post_result = self.postprocess_op(preds, shape_list)
|
post_result = self.postprocess_op(preds, shape_list)
|
||||||
points, strs = post_result['points'], post_result['strs']
|
points, strs = post_result['points'], post_result['texts']
|
||||||
dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
|
dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
|
||||||
elapse = time.time() - starttime
|
elapse = time.time() - starttime
|
||||||
return dt_boxes, strs, elapse
|
return dt_boxes, strs, elapse
|
||||||
|
|
|
@ -103,7 +103,7 @@ def main():
|
||||||
images = paddle.to_tensor(images)
|
images = paddle.to_tensor(images)
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
post_result = post_process_class(preds, shape_list)
|
post_result = post_process_class(preds, shape_list)
|
||||||
points, strs = post_result['points'], post_result['strs']
|
points, strs = post_result['points'], post_result['texts']
|
||||||
# write resule
|
# write resule
|
||||||
dt_boxes_json = []
|
dt_boxes_json = []
|
||||||
for poly, str in zip(points, strs):
|
for poly, str in zip(points, strs):
|
||||||
|
|
Loading…
Reference in New Issue