commit 9b2c0e4838
@@ -123,7 +123,7 @@ class BaseRecLabelEncode(object):
                    [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
            length: length of each text. [batch_size]
        """
        if len(text) > self.max_text_len:
        if len(text) == 0 or len(text) > self.max_text_len:
            return None
        if self.character_type == "en":
            text = text.lower()
@@ -138,9 +138,6 @@ class BaseRecLabelEncode(object):
            return None
        return text_list

    def get_ignored_tokens(self):
        return [0]  # for ctc blank


class CTCLabelEncode(BaseRecLabelEncode):
    """ Convert between text-label and text-index """
@@ -160,8 +157,6 @@ class CTCLabelEncode(BaseRecLabelEncode):
        text = self.encode(text)
        if text is None:
            return None
        if len(text) > self.max_text_len:
            return None
        data['length'] = np.array(len(text))
        text = text + [0] * (self.max_text_len - len(text))
        data['label'] = np.array(text)
@@ -195,11 +190,6 @@ class AttnLabelEncode(BaseRecLabelEncode):
        text = self.encode(text)
        return text

    def get_ignored_tokens(self):
        beg_idx = self.get_beg_end_flag_idx("beg")
        end_idx = self.get_beg_end_flag_idx("end")
        return [beg_idx, end_idx]

    def get_beg_end_flag_idx(self, beg_or_end):
        if beg_or_end == "beg":
            idx = np.array(self.dict[self.beg_str])
@@ -82,7 +82,7 @@ class TextClassifier(object):

        cls_res = [['', 0.0]] * img_num
        batch_num = self.cls_batch_num
        predict_time = 0
        elapse = 0
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            norm_img_batch = []
@@ -107,14 +107,14 @@ class TextClassifier(object):
            self.predictor.run([norm_img_batch])
            prob_out = self.output_tensors[0].copy_to_cpu()
            cls_res = self.postprocess_op(prob_out)
            elapse = time.time() - starttime
            elapse += time.time() - starttime
            for rno in range(len(cls_res)):
                label, score = cls_res[rno]
                cls_res[indices[beg_img_no + rno]] = [label, score]
                if '180' in label and score > self.cls_thresh:
                    img_list[indices[beg_img_no + rno]] = cv2.rotate(
                        img_list[indices[beg_img_no + rno]], 1)
        return img_list, cls_res, predict_time
        return img_list, cls_res, elapse


def main(args):
@@ -143,10 +143,10 @@ def main(args):
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
        exit()
    for ino in range(len(img_list)):
        print("Predicts of %s:%s" % (valid_image_file_list[ino], cls_res[ino]))
    print("Total predict time for %d images, cost: %.3f" %
          (len(img_list), predict_time))
        print("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[
            ino]))
    print("Total predict time for {} images, cost: {:.3f}".format(
        len(img_list), predict_time))


if __name__ == "__main__":
    main(utility.parse_args())
if __name__ == "__main__":
    main(utility.parse_args())
@@ -174,15 +174,15 @@ if __name__ == "__main__":
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        dt_boxes, elapse = text_detector(img)
        if count > 0:
            total_time += elapse
        count += 1
        print("Predict time of %s:" % image_file, elapse)
        print("Predict time of {}: {}".format(image_file, elapse))
        src_im = utility.draw_text_det_res(dt_boxes, image_file)
        img_name_pure = image_file.split("/")[-1]
        cv2.imwrite(
            os.path.join(draw_img_save, "det_res_%s" % img_name_pure), src_im)
        img_name_pure = os.path.split(image_file)[-1]
        img_path = os.path.join(draw_img_save,
                                "det_res_{}".format(img_name_pure))
        cv2.imwrite(img_path, src_im)
    if count > 1:
        print("Avg Time:", total_time / (count - 1))
@@ -115,7 +115,7 @@ class TextRecognizer(object):
            rec_result = self.postprocess_op(preds)
            for rno in range(len(rec_result)):
                rec_res[indices[beg_img_no + rno]] = rec_result[rno]
            elapse = time.time() - starttime
            elapse += time.time() - starttime
        return rec_res, elapse

@@ -145,9 +145,10 @@ def main(args):
            "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
        exit()
    for ino in range(len(img_list)):
        print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
    print("Total predict time for %d images, cost: %.3f" %
          (len(img_list), predict_time))
        print("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[
            ino]))
    print("Total predict time for {} images, cost: {:.3f}".format(
        len(img_list), predict_time))


if __name__ == "__main__":
@@ -236,7 +236,6 @@ def train(config,
                train_batch_cost = 0.0
                train_reader_cost = 0.0
                batch_sum = 0
                batch_start = time.time()
            # eval
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
@@ -275,6 +274,7 @@ def train(config,
                        best_model_dict[main_indicator],
                        global_step)
            global_step += 1
            batch_start = time.time()
        if dist.get_rank() == 0:
            save_model(
                model,
@@ -333,20 +333,6 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
    return metirc


def save_inference_mode(model, config, logger):
    model.eval()
    save_path = '{}/infer/{}'.format(config['Global']['save_model_dir'],
                                     config['Architecture']['model_type'])
    if config['Architecture']['model_type'] == 'rec':
        input_shape = [None, 3, 32, None]
        jit_model = paddle.jit.to_static(
            model, input_spec=[paddle.static.InputSpec(input_shape)])
        paddle.jit.save(jit_model, save_path)
        logger.info('inference model save to {}'.format(save_path))

    model.train()


def preprocess():
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
@@ -89,7 +89,6 @@ def main(config, device, logger, vdl_writer):
    program.train(config, train_dataloader, valid_dataloader, device, model,
                  loss_class, optimizer, lr_scheduler, post_process_class,
                  eval_class, pre_best_model_dict, logger, vdl_writer)
    program.save_inference_mode(model, config, logger)


def test_reader(config, device, logger):