fix inference in tps

This commit is contained in:
tink2123 2020-06-03 15:49:18 +08:00
parent b722eb56c8
commit be3a164424
6 changed files with 30 additions and 7 deletions

View File

@ -12,8 +12,10 @@ Global:
test_batch_size_per_card: 256 test_batch_size_per_card: 256
image_shape: [3, 32, 100] image_shape: [3, 32, 100]
max_text_length: 25 max_text_length: 25
character_type: en character_type: ch
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
loss_type: attention loss_type: attention
tps: true
reader_yml: ./configs/rec/rec_benchmark_reader.yml reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:

View File

@ -14,6 +14,7 @@ Global:
max_text_length: 25 max_text_length: 25
character_type: en character_type: en
loss_type: ctc loss_type: ctc
tps: true
reader_yml: ./configs/rec/rec_benchmark_reader.yml reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:

View File

@ -41,6 +41,8 @@ class LMDBReader(object):
self.loss_type = params['loss_type'] self.loss_type = params['loss_type']
self.max_text_length = params['max_text_length'] self.max_text_length = params['max_text_length']
self.mode = params['mode'] self.mode = params['mode']
if "tps" in params:
self.tps = True
if params['mode'] == 'train': if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card'] self.batch_size = params['train_batch_size_per_card']
self.drop_last = params['drop_last'] self.drop_last = params['drop_last']
@ -109,7 +111,8 @@ class LMDBReader(object):
norm_img = process_image( norm_img = process_image(
img=img, img=img,
image_shape=self.image_shape, image_shape=self.image_shape,
char_ops=self.char_ops) char_ops=self.char_ops,
tps=self.tps)
yield norm_img yield norm_img
else: else:
lmdb_sets = self.load_hierarchical_lmdb_dataset() lmdb_sets = self.load_hierarchical_lmdb_dataset()

View File

@ -92,11 +92,16 @@ def process_image(img,
label=None, label=None,
char_ops=None, char_ops=None,
loss_type=None, loss_type=None,
max_text_length=None): max_text_length=None,
tps=None):
if char_ops.character_type == "en": if char_ops.character_type == "en":
norm_img = resize_norm_img(img, image_shape) norm_img = resize_norm_img(img, image_shape)
else: else:
norm_img = resize_norm_img_chinese(img, image_shape) if tps:
image_shape = [3, 32, 320]
norm_img = resize_norm_img(img, image_shape)
else:
norm_img = resize_norm_img_chinese(img, image_shape)
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]
if label is not None: if label is not None:
char_num = char_ops.get_char_num() char_num = char_ops.get_char_num()

View File

@ -30,6 +30,7 @@ class RecModel(object):
global_params = params['Global'] global_params = params['Global']
char_num = global_params['char_ops'].get_char_num() char_num = global_params['char_ops'].get_char_num()
global_params['char_num'] = char_num global_params['char_num'] = char_num
self.char_type = global_params['character_type']
if "TPS" in params: if "TPS" in params:
tps_params = deepcopy(params["TPS"]) tps_params = deepcopy(params["TPS"])
tps_params.update(global_params) tps_params.update(global_params)
@ -60,8 +61,8 @@ class RecModel(object):
def create_feed(self, mode): def create_feed(self, mode):
image_shape = deepcopy(self.image_shape) image_shape = deepcopy(self.image_shape)
image_shape.insert(0, -1) image_shape.insert(0, -1)
image = fluid.data(name='image', shape=image_shape, dtype='float32')
if mode == "train": if mode == "train":
image = fluid.data(name='image', shape=image_shape, dtype='float32')
if self.loss_type == "attention": if self.loss_type == "attention":
label_in = fluid.data( label_in = fluid.data(
name='label_in', name='label_in',
@ -86,6 +87,17 @@ class RecModel(object):
use_double_buffer=True, use_double_buffer=True,
iterable=False) iterable=False)
else: else:
if self.char_type == "ch":
image_shape[-1] = -1
if self.tps != None:
logger.info(
"WARNRNG!!!\n"
"TPS does not support variable shape in chinese!"
"We set default shape=[3,32,320], it may affect the inference effect"
)
image_shape[-1] = 320
image = fluid.data(
name='image', shape=image_shape, dtype='float32')
labels = None labels = None
loader = None loader = None
return image, labels, loader return image, labels, loader

View File

@ -112,7 +112,7 @@ class TextRecognizer(object):
else: else:
preds = rec_idx_batch[rno, 1:end_pos[1]] preds = rec_idx_batch[rno, 1:end_pos[1]]
score = np.mean(predict_batch[rno, 1:end_pos[1]]) score = np.mean(predict_batch[rno, 1:end_pos[1]])
#todo: why index has 2 offset #attenton index has 2 offset: beg and end
preds = preds - 2 preds = preds - 2
preds_text = self.char_ops.decode(preds) preds_text = self.char_ops.decode(preds)
rec_res.append([preds_text, score]) rec_res.append([preds_text, score])
@ -138,7 +138,7 @@ if __name__ == "__main__":
except: except:
logger.info( logger.info(
"ERROR!! \nInput image shape is not equal with config. TPS does not support variable shape.\n" "ERROR!! \nInput image shape is not equal with config. TPS does not support variable shape.\n"
"Please set --rec_image_shape=input_shape and --rec_char_type='ch' ") "Please set --rec_image_shape=input_shape and --rec_char_type='en' ")
exit() exit()
for ino in range(len(img_list)): for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))