Merge branch 'dygraph' into update_requirements

Authored by MissPenguin on 2021-04-16 20:06:16 +08:00, committed by GitHub
commit 0e6e30bad1
10 changed files with 66 additions and 131 deletions

View File

@@ -62,20 +62,21 @@ PostProcess:
   mode: fast # fast or slow two ways

 Metric:
   name: E2EMetric
-  gt_mat_dir:  # the dir of gt_mat
+  gt_mat_dir: ./train_data/total_text/gt  # the dir of gt_mat
   character_dict_path: ppocr/utils/ic15_dict.txt
   main_indicator: f_score_e2e

 Train:
   dataset:
     name: PGDataSet
-    label_file_list: [.././train_data/total_text/train/]
+    data_dir: ./train_data/total_text/train
+    label_file_list: [./train_data/total_text/train/]
     ratio_list: [1.0]
-    data_format: icdar #two data format: icdar/textnet
     transforms:
       - DecodeImage: # load image
           img_mode: BGR
           channel_first: False
+      - E2ELabelEncode:
       - PGProcessTrain:
           batch_size: 14  # same as loader: batch_size_per_card
           min_crop_size: 24
@@ -92,13 +93,12 @@ Train:
 Eval:
   dataset:
     name: PGDataSet
-    data_dir: ./train_data/
+    data_dir: ./train_data/total_text/test
     label_file_list: [./train_data/total_text/test/]
     transforms:
       - DecodeImage: # load image
           img_mode: RGB
           channel_first: False
-      - E2ELabelEncode:
       - E2EResizeForTest:
           max_side_len: 768
       - NormalizeImage:
@@ -108,7 +108,7 @@ Eval:
           order: 'hwc'
       - ToCHWImage:
       - KeepKeys:
-          keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags', 'img_id']
+          keep_keys: [ 'image', 'shape', 'img_id']
   loader:
     shuffle: False
     drop_last: False
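Note: with data_format: icdar removed, PGDataSet now reads a data_dir plus plain label files. Judging from the rewritten E2ELabelEncode and PGDataSet further down in this diff, each label line is assumed to be <image file name> TAB <JSON list of {"transcription", "points"}> entries. The Python sketch below writes one such line; all file and directory names here are illustrative, not taken from the commit.

import json

# One annotation entry per text instance; '###' marks a region to ignore.
annotation = [
    {"transcription": "HELLO", "points": [[10, 10], [120, 10], [120, 40], [10, 40]]},
    {"transcription": "###", "points": [[5, 60], [80, 60], [80, 90], [5, 90]]},
]

# Hypothetical label file and image name; PGDataSet splits each line on '\t' by default.
line = "rgb_img11.jpg\t" + json.dumps(annotation) + "\n"
with open("./train_data/total_text/train/train.txt", "w", encoding="utf-8") as f:
    f.write(line)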

View File

@@ -187,29 +187,31 @@ class CTCLabelEncode(BaseRecLabelEncode):
         return dict_character


-class E2ELabelEncode(BaseRecLabelEncode):
-    def __init__(self,
-                 max_text_length,
-                 character_dict_path=None,
-                 character_type='EN',
-                 use_space_char=False,
-                 **kwargs):
-        super(E2ELabelEncode,
-              self).__init__(max_text_length, character_dict_path,
-                             character_type, use_space_char)
-        self.pad_num = len(self.dict)  # the length to pad
+class E2ELabelEncode(object):
+    def __init__(self, **kwargs):
+        pass

     def __call__(self, data):
-        texts = data['strs']
-        temp_texts = []
-        for text in texts:
-            text = text.lower()
-            text = self.encode(text)
-            if text is None:
-                return None
-            text = text + [self.pad_num] * (self.max_text_len - len(text))
-            temp_texts.append(text)
-        data['strs'] = np.array(temp_texts)
+        import json
+        label = data['label']
+        label = json.loads(label)
+        nBox = len(label)
+        boxes, txts, txt_tags = [], [], []
+        for bno in range(0, nBox):
+            box = label[bno]['points']
+            txt = label[bno]['transcription']
+            boxes.append(box)
+            txts.append(txt)
+            if txt in ['*', '###']:
+                txt_tags.append(True)
+            else:
+                txt_tags.append(False)
+        boxes = np.array(boxes, dtype=np.float32)
+        txt_tags = np.array(txt_tags, dtype=np.bool)
+        data['polys'] = boxes
+        data['texts'] = txts
+        data['ignore_tags'] = txt_tags
         return data
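For reference, a minimal usage sketch of the rewritten E2ELabelEncode. The import path is assumed to be ppocr.data.imaug.label_ops, as elsewhere in PaddleOCR (the file name is not shown on this page), and the sample label is made up.

import json
from ppocr.data.imaug.label_ops import E2ELabelEncode  # assumed module path

encoder = E2ELabelEncode()
data = {
    "label": json.dumps([
        {"transcription": "HELLO", "points": [[10, 10], [120, 10], [120, 40], [10, 40]]},
        {"transcription": "###", "points": [[5, 60], [80, 60], [80, 90], [5, 90]]},
    ])
}
data = encoder(data)
# data now carries 'polys' (float32 boxes), 'texts' (raw strings) and
# 'ignore_tags' (True for '*' / '###' entries), the keys PGProcessTrain reads below.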

View File

@@ -88,7 +88,7 @@ class PGProcessTrain(object):
         return min_area_quad

-    def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
+    def check_and_validate_polys(self, polys, tags, im_size):
         """
         check so that the text poly is in the same direction,
         and also filter some invalid polygons
@@ -96,7 +96,7 @@ class PGProcessTrain(object):
         :param tags:
         :return:
         """
-        (h, w) = xxx_todo_changeme
+        (h, w) = im_size
         if polys.shape[0] == 0:
             return polys, np.array([]), np.array([])
         polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
@@ -750,8 +750,8 @@ class PGProcessTrain(object):
         input_size = 512
         im = data['image']
         text_polys = data['polys']
-        text_tags = data['tags']
-        text_strs = data['strs']
+        text_tags = data['ignore_tags']
+        text_strs = data['texts']
         h, w, _ = im.shape
         text_polys, text_tags, hv_tags = self.check_and_validate_polys(
             text_polys, text_tags, (h, w))

View File

@@ -29,20 +29,20 @@ class PGDataSet(Dataset):
         dataset_config = config[mode]['dataset']
         loader_config = config[mode]['loader']

+        self.delimiter = dataset_config.get('delimiter', '\t')
         label_file_list = dataset_config.pop('label_file_list')
         data_source_num = len(label_file_list)
         ratio_list = dataset_config.get("ratio_list", [1.0])
         if isinstance(ratio_list, (float, int)):
             ratio_list = [float(ratio_list)] * int(data_source_num)
-        self.data_format = dataset_config.get('data_format', 'icdar')
         assert len(
             ratio_list
         ) == data_source_num, "The length of ratio_list should be the same as the file_list."
+        self.data_dir = dataset_config['data_dir']
         self.do_shuffle = loader_config['shuffle']

         logger.info("Initialize indexs of datasets:%s" % label_file_list)
-        self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
-                                                   self.data_format)
+        self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
         self.data_idx_order_list = list(range(len(self.data_lines)))
         if mode.lower() == "train":
             self.shuffle_data_random()
@@ -55,108 +55,40 @@ class PGDataSet(Dataset):
         random.shuffle(self.data_lines)
         return

-    def extract_polys(self, poly_txt_path):
-        """
-        Read text_polys, txt_tags, txts from give txt file.
-        """
-        text_polys, txt_tags, txts = [], [], []
-        with open(poly_txt_path) as f:
-            for line in f.readlines():
-                poly_str, txt = line.strip().split('\t')
-                poly = list(map(float, poly_str.split(',')))
-                text_polys.append(
-                    np.array(
-                        poly, dtype=np.float32).reshape(-1, 2))
-                txts.append(txt)
-                txt_tags.append(txt == '###')
-        return np.array(list(map(np.array, text_polys))), \
-            np.array(txt_tags, dtype=np.bool), txts
-
-    def extract_info_textnet(self, im_fn, img_dir=''):
-        """
-        Extract information from line in textnet format.
-        """
-        info_list = im_fn.split('\t')
-        img_path = ''
-        for ext in [
-                'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'JPG'
-        ]:
-            if os.path.exists(os.path.join(img_dir, info_list[0] + "." + ext)):
-                img_path = os.path.join(img_dir, info_list[0] + "." + ext)
-                break
-        if img_path == '':
-            print('Image {0} NOT found in {1}, and it will be ignored.'.format(
-                info_list[0], img_dir))
-
-        nBox = (len(info_list) - 1) // 9
-        wordBBs, txts, txt_tags = [], [], []
-        for n in range(0, nBox):
-            wordBB = list(map(float, info_list[n * 9 + 1:(n + 1) * 9]))
-            txt = info_list[(n + 1) * 9]
-            wordBBs.append([[wordBB[0], wordBB[1]], [wordBB[2], wordBB[3]],
-                            [wordBB[4], wordBB[5]], [wordBB[6], wordBB[7]]])
-            txts.append(txt)
-            if txt == '###':
-                txt_tags.append(True)
-            else:
-                txt_tags.append(False)
-        return img_path, np.array(wordBBs, dtype=np.float32), txt_tags, txts
-
-    def get_image_info_list(self, file_list, ratio_list, data_format='textnet'):
+    def get_image_info_list(self, file_list, ratio_list):
         if isinstance(file_list, str):
             file_list = [file_list]
         data_lines = []
-        for idx, data_source in enumerate(file_list):
-            image_files = []
-            if data_format == 'icdar':
-                image_files = [(data_source, x) for x in
-                               os.listdir(os.path.join(data_source, 'rgb'))
-                               if x.split('.')[-1] in [
-                                   'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif',
-                                   'tiff', 'gif', 'JPG'
-                               ]]
-            elif data_format == 'textnet':
-                with open(data_source) as f:
-                    image_files = [(data_source, x.strip())
-                                   for x in f.readlines()]
-            else:
-                print("Unrecognized data format...")
-                exit(-1)
-            random.seed(self.seed)
-            image_files = random.sample(
-                image_files, round(len(image_files) * ratio_list[idx]))
-            data_lines.extend(image_files)
+        for idx, file in enumerate(file_list):
+            with open(file, "rb") as f:
+                lines = f.readlines()
+                if self.mode == "train" or ratio_list[idx] < 1.0:
+                    random.seed(self.seed)
+                    lines = random.sample(lines,
+                                          round(len(lines) * ratio_list[idx]))
+                data_lines.extend(lines)
         return data_lines

     def __getitem__(self, idx):
         file_idx = self.data_idx_order_list[idx]
-        data_path, data_line = self.data_lines[file_idx]
+        data_line = self.data_lines[file_idx]
         try:
-            if self.data_format == 'icdar':
-                im_path = os.path.join(data_path, 'rgb', data_line)
-                poly_path = os.path.join(data_path, 'poly',
-                                         data_line.split('.')[0] + '.txt')
-                text_polys, text_tags, text_strs = self.extract_polys(poly_path)
-            else:
-                image_dir = os.path.join(os.path.dirname(data_path), 'image')
-                im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
-                    data_line, image_dir)
-            img_id = int(data_line.split(".")[0][3:])
-            data = {
-                'img_path': im_path,
-                'polys': text_polys,
-                'tags': text_tags,
-                'strs': text_strs,
-                'img_id': img_id
-            }
+            data_line = data_line.decode('utf-8')
+            substr = data_line.strip("\n").split(self.delimiter)
+            file_name = substr[0]
+            label = substr[1]
+            img_path = os.path.join(self.data_dir, file_name)
+            if self.mode.lower() == 'eval':
+                img_id = int(data_line.split(".")[0][7:])
+            else:
+                img_id = 0
+            data = {'img_path': img_path, 'label': label, 'img_id': img_id}
+            if not os.path.exists(img_path):
+                raise Exception("{} does not exist!".format(img_path))
             with open(data['img_path'], 'rb') as f:
                 img = f.read()
                 data['image'] = img
             outs = transform(data, self.ops)
         except Exception as e:
             self.logger.error(
                 "When parsing line {}, error happened with msg: {}".format(

View File

@@ -35,11 +35,11 @@ class E2EMetric(object):
         self.reset()

     def __call__(self, preds, batch, **kwargs):
-        img_id = batch[5][0]
+        img_id = batch[2][0]
         e2e_info_list = [{
             'points': det_polyon,
-            'text': pred_str
-        } for det_polyon, pred_str in zip(preds['points'], preds['strs'])]
+            'texts': pred_str
+        } for det_polyon, pred_str in zip(preds['points'], preds['texts'])]

         result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
         self.results.append(result)

View File

@@ -26,7 +26,7 @@ def get_socre(gt_dir, img_id, pred_dict):
     n = len(pred_dict)
     for i in range(n):
         points = pred_dict[i]['points']
-        text = pred_dict[i]['text']
+        text = pred_dict[i]['texts']
         point = ",".join(map(str, points.reshape(-1, )))
         det.append([point, text])
     return det

View File

@@ -21,6 +21,7 @@
 import math

 import numpy as np
 from itertools import groupby
+from cv2.ximgproc import thinning as thin
 from skimage.morphology._skeletonize import thin
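The added thinning import lives in OpenCV's extended modules: cv2.ximgproc ships with opencv-contrib-python, not plain opencv-python, which is presumably why the requirements are being updated in this branch. The commit simply keeps both imports; a defensive variant (not what the commit does) could fall back to the skimage implementation:

try:
    # needs opencv-contrib-python; plain opencv-python has no cv2.ximgproc
    from cv2.ximgproc import thinning as thin
except ImportError:
    # fall back to the skimage thinning already imported in this file
    # (inputs/outputs differ slightly: cv2 works on 8-bit masks, skimage on booleans)
    from skimage.morphology._skeletonize import thin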

View File

@@ -64,7 +64,7 @@ class PGNet_PostProcess(object):
             src_w, src_h, self.valid_set)
         data = {
             'points': poly_list,
-            'strs': keep_str_list,
+            'texts': keep_str_list,
         }
         return data
@@ -176,6 +176,6 @@ class PGNet_PostProcess(object):
             exit(-1)
         data = {
             'points': poly_list,
-            'strs': keep_str_list,
+            'texts': keep_str_list,
         }
         return data

View File

@@ -122,7 +122,7 @@ class TextE2E(object):
         else:
             raise NotImplementedError
         post_result = self.postprocess_op(preds, shape_list)
-        points, strs = post_result['points'], post_result['strs']
+        points, strs = post_result['points'], post_result['texts']
         dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
         elapse = time.time() - starttime
         return dt_boxes, strs, elapse

View File

@@ -103,7 +103,7 @@ def main():
         images = paddle.to_tensor(images)
         preds = model(images)
         post_result = post_process_class(preds, shape_list)
-        points, strs = post_result['points'], post_result['strs']
+        points, strs = post_result['points'], post_result['texts']
         # write resule
         dt_boxes_json = []
         for poly, str in zip(points, strs):