Merge remote-tracking branch 'PaddlePaddle/dygraph' into dygraph
This commit is contained in:
commit
dab3e8f3f6
|
@ -62,20 +62,21 @@ PostProcess:
|
|||
mode: fast # two modes are available: fast or slow
|
||||
Metric:
|
||||
name: E2EMetric
|
||||
gt_mat_dir: # the dir of gt_mat
|
||||
gt_mat_dir: ./train_data/total_text/gt # the dir of gt_mat
|
||||
character_dict_path: ppocr/utils/ic15_dict.txt
|
||||
main_indicator: f_score_e2e
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: PGDataSet
|
||||
label_file_list: [.././train_data/total_text/train/]
|
||||
data_dir: ./train_data/total_text/train
|
||||
label_file_list: [./train_data/total_text/train/]
|
||||
ratio_list: [1.0]
|
||||
data_format: icdar # two data formats: icdar/textnet
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- E2ELabelEncode:
|
||||
- PGProcessTrain:
|
||||
batch_size: 14 # same as loader: batch_size_per_card
|
||||
min_crop_size: 24
|
||||
|
@ -92,13 +93,12 @@ Train:
|
|||
Eval:
|
||||
dataset:
|
||||
name: PGDataSet
|
||||
data_dir: ./train_data/
|
||||
data_dir: ./train_data/total_text/test
|
||||
label_file_list: [./train_data/total_text/test/]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
channel_first: False
|
||||
- E2ELabelEncode:
|
||||
- E2EResizeForTest:
|
||||
max_side_len: 768
|
||||
- NormalizeImage:
|
||||
|
@ -108,7 +108,7 @@ Eval:
|
|||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags', 'img_id']
|
||||
keep_keys: [ 'image', 'shape', 'img_id']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
|
|
|
@ -37,7 +37,7 @@ Architecture:
|
|||
name: TPS
|
||||
num_fiducial: 20
|
||||
loc_lr: 0.1
|
||||
model_name: small
|
||||
model_name: large
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 34
|
||||
|
|
|
@ -187,29 +187,31 @@ class CTCLabelEncode(BaseRecLabelEncode):
|
|||
return dict_character
|
||||
|
||||
|
||||
class E2ELabelEncode(BaseRecLabelEncode):
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
character_dict_path=None,
|
||||
character_type='EN',
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
super(E2ELabelEncode,
|
||||
self).__init__(max_text_length, character_dict_path,
|
||||
character_type, use_space_char)
|
||||
self.pad_num = len(self.dict) # the length to pad
|
||||
class E2ELabelEncode(object):
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def __call__(self, data):
|
||||
texts = data['strs']
|
||||
temp_texts = []
|
||||
for text in texts:
|
||||
text = text.lower()
|
||||
text = self.encode(text)
|
||||
if text is None:
|
||||
return None
|
||||
text = text + [self.pad_num] * (self.max_text_len - len(text))
|
||||
temp_texts.append(text)
|
||||
data['strs'] = np.array(temp_texts)
|
||||
import json
|
||||
label = data['label']
|
||||
label = json.loads(label)
|
||||
nBox = len(label)
|
||||
boxes, txts, txt_tags = [], [], []
|
||||
for bno in range(0, nBox):
|
||||
box = label[bno]['points']
|
||||
txt = label[bno]['transcription']
|
||||
boxes.append(box)
|
||||
txts.append(txt)
|
||||
if txt in ['*', '###']:
|
||||
txt_tags.append(True)
|
||||
else:
|
||||
txt_tags.append(False)
|
||||
boxes = np.array(boxes, dtype=np.float32)
|
||||
txt_tags = np.array(txt_tags, dtype=np.bool)
|
||||
|
||||
data['polys'] = boxes
|
||||
data['texts'] = txts
|
||||
data['ignore_tags'] = txt_tags
|
||||
return data
|
||||
|
||||
|
||||
|
|
|
@ -88,7 +88,7 @@ class PGProcessTrain(object):
|
|||
|
||||
return min_area_quad
|
||||
|
||||
def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
|
||||
def check_and_validate_polys(self, polys, tags, im_size):
|
||||
"""
|
||||
check so that the text poly is in the same direction,
|
||||
and also filter some invalid polygons
|
||||
|
@ -96,7 +96,7 @@ class PGProcessTrain(object):
|
|||
:param tags:
|
||||
:return:
|
||||
"""
|
||||
(h, w) = xxx_todo_changeme
|
||||
(h, w) = im_size
|
||||
if polys.shape[0] == 0:
|
||||
return polys, np.array([]), np.array([])
|
||||
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
|
||||
|
@ -750,8 +750,8 @@ class PGProcessTrain(object):
|
|||
input_size = 512
|
||||
im = data['image']
|
||||
text_polys = data['polys']
|
||||
text_tags = data['tags']
|
||||
text_strs = data['strs']
|
||||
text_tags = data['ignore_tags']
|
||||
text_strs = data['texts']
|
||||
h, w, _ = im.shape
|
||||
text_polys, text_tags, hv_tags = self.check_and_validate_polys(
|
||||
text_polys, text_tags, (h, w))
|
||||
|
|
|
@ -29,20 +29,20 @@ class PGDataSet(Dataset):
|
|||
dataset_config = config[mode]['dataset']
|
||||
loader_config = config[mode]['loader']
|
||||
|
||||
self.delimiter = dataset_config.get('delimiter', '\t')
|
||||
label_file_list = dataset_config.pop('label_file_list')
|
||||
data_source_num = len(label_file_list)
|
||||
ratio_list = dataset_config.get("ratio_list", [1.0])
|
||||
if isinstance(ratio_list, (float, int)):
|
||||
ratio_list = [float(ratio_list)] * int(data_source_num)
|
||||
self.data_format = dataset_config.get('data_format', 'icdar')
|
||||
assert len(
|
||||
ratio_list
|
||||
) == data_source_num, "The length of ratio_list should be the same as the file_list."
|
||||
self.data_dir = dataset_config['data_dir']
|
||||
self.do_shuffle = loader_config['shuffle']
|
||||
|
||||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
|
||||
self.data_format)
|
||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||
if mode.lower() == "train":
|
||||
self.shuffle_data_random()
|
||||
|
@ -55,108 +55,40 @@ class PGDataSet(Dataset):
|
|||
random.shuffle(self.data_lines)
|
||||
return
|
||||
|
||||
def extract_polys(self, poly_txt_path):
    """
    Read text_polys, txt_tags, txts from the given txt file.

    Every non-blank line must contain a comma-separated list of
    coordinates, a tab, and then the transcription.

    Args:
        poly_txt_path: path of the annotation file to parse.

    Returns:
        A tuple ``(text_polys, txt_tags, txts)``:
        - text_polys: array of the polygons, each one a float32 array
          reshaped to ``(-1, 2)``; a 1-D object array is returned when
          the polygons do not all have the same number of points.
        - txt_tags: boolean array, True where the transcription is '###'.
        - txts: list with the transcription of every polygon.

    Raises:
        ValueError: if a non-blank line has no tab separator or holds a
            coordinate that cannot be parsed as a float.
    """
    text_polys, txt_tags, txts = [], [], []
    with open(poly_txt_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no annotation.
                continue
            # Split on the first tab only, so a transcription that itself
            # contains a tab does not break the two-value unpacking.
            poly_str, txt = line.split('\t', 1)
            coords = [float(value) for value in poly_str.split(',')]
            text_polys.append(
                np.array(coords, dtype=np.float32).reshape(-1, 2))
            txts.append(txt)
            txt_tags.append(txt == '###')

    try:
        polys = np.array(text_polys)
    except ValueError:
        # Polygons with differing point counts cannot form a regular
        # array; recent NumPy raises instead of silently building an
        # object array, so build that object array explicitly.
        polys = np.empty(len(text_polys), dtype=object)
        for i, poly in enumerate(text_polys):
            polys[i] = poly

    # The ``np.bool`` alias was removed from NumPy; the builtin ``bool``
    # selects the same dtype.
    return polys, np.array(txt_tags, dtype=bool), txts
|
||||
|
||||
def extract_info_textnet(self, im_fn, img_dir=''):
    """
    Extract information from line in textnet format.

    The line is tab-separated: the first field is the image name without
    extension, followed by groups of nine fields (eight coordinates and
    one transcription) per box.

    Args:
        im_fn: the tab-separated annotation line.
        img_dir: directory that is searched for the image file.

    Returns:
        A tuple ``(img_path, boxes, txt_tags, txts)``: the path of the
        first existing ``<name>.<ext>`` candidate ('' when none exists),
        a float32 array of the boxes as four (x, y) pairs each, a list of
        booleans that are True where the transcription is '###', and the
        list of transcriptions.
    """
    fields = im_fn.split('\t')
    stem = fields[0]

    # Probe the known extensions in order and keep the first hit.
    candidates = (os.path.join(img_dir, stem + "." + suffix)
                  for suffix in ('jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif',
                                 'tiff', 'gif', 'JPG'))
    img_path = next((path for path in candidates if os.path.exists(path)),
                    '')
    if not img_path:
        print('Image {0} NOT found in {1}, and it will be ignored.'.format(
            stem, img_dir))

    box_count = (len(fields) - 1) // 9
    quads, txts, txt_tags = [], [], []
    for k in range(box_count):
        start = 9 * k + 1
        coords = [float(value) for value in fields[start:start + 8]]
        label = fields[start + 8]
        # Regroup the eight flat values into four (x, y) corner pairs.
        quads.append([coords[0:2], coords[2:4], coords[4:6], coords[6:8]])
        txts.append(label)
        txt_tags.append(label == '###')
    return img_path, np.array(quads, dtype=np.float32), txt_tags, txts
|
||||
|
||||
def get_image_info_list(self, file_list, ratio_list, data_format='textnet'):
|
||||
def get_image_info_list(self, file_list, ratio_list):
|
||||
if isinstance(file_list, str):
|
||||
file_list = [file_list]
|
||||
data_lines = []
|
||||
for idx, data_source in enumerate(file_list):
|
||||
image_files = []
|
||||
if data_format == 'icdar':
|
||||
image_files = [(data_source, x) for x in
|
||||
os.listdir(os.path.join(data_source, 'rgb'))
|
||||
if x.split('.')[-1] in [
|
||||
'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif',
|
||||
'tiff', 'gif', 'JPG'
|
||||
]]
|
||||
elif data_format == 'textnet':
|
||||
with open(data_source) as f:
|
||||
image_files = [(data_source, x.strip())
|
||||
for x in f.readlines()]
|
||||
else:
|
||||
print("Unrecognized data format...")
|
||||
exit(-1)
|
||||
for idx, file in enumerate(file_list):
|
||||
with open(file, "rb") as f:
|
||||
lines = f.readlines()
|
||||
if self.mode == "train" or ratio_list[idx] < 1.0:
|
||||
random.seed(self.seed)
|
||||
image_files = random.sample(
|
||||
image_files, round(len(image_files) * ratio_list[idx]))
|
||||
data_lines.extend(image_files)
|
||||
lines = random.sample(lines,
|
||||
round(len(lines) * ratio_list[idx]))
|
||||
data_lines.extend(lines)
|
||||
return data_lines
|
||||
|
||||
def __getitem__(self, idx):
|
||||
file_idx = self.data_idx_order_list[idx]
|
||||
data_path, data_line = self.data_lines[file_idx]
|
||||
data_line = self.data_lines[file_idx]
|
||||
try:
|
||||
if self.data_format == 'icdar':
|
||||
im_path = os.path.join(data_path, 'rgb', data_line)
|
||||
poly_path = os.path.join(data_path, 'poly',
|
||||
data_line.split('.')[0] + '.txt')
|
||||
text_polys, text_tags, text_strs = self.extract_polys(poly_path)
|
||||
data_line = data_line.decode('utf-8')
|
||||
substr = data_line.strip("\n").split(self.delimiter)
|
||||
file_name = substr[0]
|
||||
label = substr[1]
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
if self.mode.lower() == 'eval':
|
||||
img_id = int(data_line.split(".")[0][7:])
|
||||
else:
|
||||
image_dir = os.path.join(os.path.dirname(data_path), 'image')
|
||||
im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
|
||||
data_line, image_dir)
|
||||
img_id = int(data_line.split(".")[0][3:])
|
||||
|
||||
data = {
|
||||
'img_path': im_path,
|
||||
'polys': text_polys,
|
||||
'tags': text_tags,
|
||||
'strs': text_strs,
|
||||
'img_id': img_id
|
||||
}
|
||||
img_id = 0
|
||||
data = {'img_path': img_path, 'label': label, 'img_id': img_id}
|
||||
if not os.path.exists(img_path):
|
||||
raise Exception("{} does not exist!".format(img_path))
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
outs = transform(data, self.ops)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"When parsing line {}, error happened with msg: {}".format(
|
||||
|
|
|
@ -35,11 +35,11 @@ class E2EMetric(object):
|
|||
self.reset()
|
||||
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
img_id = batch[5][0]
|
||||
img_id = batch[2][0]
|
||||
e2e_info_list = [{
|
||||
'points': det_polyon,
|
||||
'text': pred_str
|
||||
} for det_polyon, pred_str in zip(preds['points'], preds['strs'])]
|
||||
'texts': pred_str
|
||||
} for det_polyon, pred_str in zip(preds['points'], preds['texts'])]
|
||||
result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
|
||||
self.results.append(result)
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ def get_socre(gt_dir, img_id, pred_dict):
|
|||
n = len(pred_dict)
|
||||
for i in range(n):
|
||||
points = pred_dict[i]['points']
|
||||
text = pred_dict[i]['text']
|
||||
text = pred_dict[i]['texts']
|
||||
point = ",".join(map(str, points.reshape(-1, )))
|
||||
det.append([point, text])
|
||||
return det
|
||||
|
|
|
@ -342,6 +342,7 @@ def generate_pivot_list_curved(p_score,
|
|||
center_pos_yxs = []
|
||||
end_points_yxs = []
|
||||
instance_center_pos_yxs = []
|
||||
pred_strs = []
|
||||
if instance_count > 0:
|
||||
for instance_id in range(1, instance_count):
|
||||
pos_list = []
|
||||
|
@ -367,12 +368,13 @@ def generate_pivot_list_curved(p_score,
|
|||
if is_backbone:
|
||||
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
|
||||
instance_center_pos_yxs.append(keep_yxs_list_with_id)
|
||||
pred_strs.append(decoded_str)
|
||||
else:
|
||||
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
|
||||
center_pos_yxs.extend(keep_yxs_list)
|
||||
|
||||
if is_backbone:
|
||||
return instance_center_pos_yxs
|
||||
return pred_strs, instance_center_pos_yxs
|
||||
else:
|
||||
return center_pos_yxs, end_points_yxs
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ class PGNet_PostProcess(object):
|
|||
src_w, src_h, self.valid_set)
|
||||
data = {
|
||||
'points': poly_list,
|
||||
'strs': keep_str_list,
|
||||
'texts': keep_str_list,
|
||||
}
|
||||
return data
|
||||
|
||||
|
@ -85,32 +85,13 @@ class PGNet_PostProcess(object):
|
|||
p_char = p_char[0]
|
||||
src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
|
||||
is_curved = self.valid_set == "totaltext"
|
||||
instance_yxs_list = generate_pivot_list_slow(
|
||||
char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow(
|
||||
p_score,
|
||||
p_char,
|
||||
p_direction,
|
||||
score_thresh=self.score_thresh,
|
||||
is_backbone=True,
|
||||
is_curved=is_curved)
|
||||
p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
|
||||
char_seq_idx_set = []
|
||||
for i in range(len(instance_yxs_list)):
|
||||
gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
|
||||
f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
|
||||
feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
|
||||
feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
|
||||
feature_len = [len(feature_seq[0])]
|
||||
featyre_seq = paddle.to_tensor(feature_seq)
|
||||
feature_len = np.array([feature_len]).astype(np.int64)
|
||||
length = paddle.to_tensor(feature_len)
|
||||
seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
|
||||
input=featyre_seq, blank=36, input_length=length)
|
||||
seq_pred_str = seq_pred[0].numpy().tolist()[0]
|
||||
seq_len = seq_pred[1].numpy()[0][0]
|
||||
temp_t = []
|
||||
for c in seq_pred_str[:seq_len]:
|
||||
temp_t.append(c)
|
||||
char_seq_idx_set.append(temp_t)
|
||||
seq_strs = []
|
||||
for char_idx_set in char_seq_idx_set:
|
||||
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
|
||||
|
@ -176,6 +157,6 @@ class PGNet_PostProcess(object):
|
|||
exit(-1)
|
||||
data = {
|
||||
'points': poly_list,
|
||||
'strs': keep_str_list,
|
||||
'texts': keep_str_list,
|
||||
}
|
||||
return data
|
||||
|
|
|
@ -3,9 +3,8 @@ scikit-image==0.17.2
|
|||
imgaug==0.4.0
|
||||
pyclipper
|
||||
lmdb
|
||||
opencv-python==4.2.0.32
|
||||
tqdm
|
||||
numpy
|
||||
visualdl
|
||||
python-Levenshtein
|
||||
opencv-contrib-python
|
||||
opencv-contrib-python==4.2.0.32
|
|
@ -122,7 +122,7 @@ class TextE2E(object):
|
|||
else:
|
||||
raise NotImplementedError
|
||||
post_result = self.postprocess_op(preds, shape_list)
|
||||
points, strs = post_result['points'], post_result['strs']
|
||||
points, strs = post_result['points'], post_result['texts']
|
||||
dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
|
||||
elapse = time.time() - starttime
|
||||
return dt_boxes, strs, elapse
|
||||
|
|
|
@ -103,7 +103,7 @@ def main():
|
|||
images = paddle.to_tensor(images)
|
||||
preds = model(images)
|
||||
post_result = post_process_class(preds, shape_list)
|
||||
points, strs = post_result['points'], post_result['strs']
|
||||
points, strs = post_result['points'], post_result['texts']
|
||||
# write result
|
||||
dt_boxes_json = []
|
||||
for poly, str in zip(points, strs):
|
||||
|
|
|
@ -196,10 +196,8 @@ def train(config,
|
|||
train_reader_cost = 0.0
|
||||
batch_sum = 0
|
||||
batch_start = time.time()
|
||||
for idx, batch in enumerate(train_dataloader):
|
||||
for idx, batch in enumerate(train_dataloader()):
|
||||
train_reader_cost += time.time() - batch_start
|
||||
if idx >= len(train_dataloader):
|
||||
break
|
||||
lr = optimizer.get_lr()
|
||||
images = batch[0]
|
||||
if use_srn:
|
||||
|
|
Loading…
Reference in New Issue