fix inference in tps
This commit is contained in:
parent
b722eb56c8
commit
be3a164424
|
@ -12,8 +12,10 @@ Global:
|
||||||
test_batch_size_per_card: 256
|
test_batch_size_per_card: 256
|
||||||
image_shape: [3, 32, 100]
|
image_shape: [3, 32, 100]
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
character_type: en
|
character_type: ch
|
||||||
|
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
|
||||||
loss_type: attention
|
loss_type: attention
|
||||||
|
tps: true
|
||||||
reader_yml: ./configs/rec/rec_benchmark_reader.yml
|
reader_yml: ./configs/rec/rec_benchmark_reader.yml
|
||||||
pretrain_weights:
|
pretrain_weights:
|
||||||
checkpoints:
|
checkpoints:
|
||||||
|
|
|
@ -14,6 +14,7 @@ Global:
|
||||||
max_text_length: 25
|
max_text_length: 25
|
||||||
character_type: en
|
character_type: en
|
||||||
loss_type: ctc
|
loss_type: ctc
|
||||||
|
tps: true
|
||||||
reader_yml: ./configs/rec/rec_benchmark_reader.yml
|
reader_yml: ./configs/rec/rec_benchmark_reader.yml
|
||||||
pretrain_weights:
|
pretrain_weights:
|
||||||
checkpoints:
|
checkpoints:
|
||||||
|
|
|
@ -41,6 +41,8 @@ class LMDBReader(object):
|
||||||
self.loss_type = params['loss_type']
|
self.loss_type = params['loss_type']
|
||||||
self.max_text_length = params['max_text_length']
|
self.max_text_length = params['max_text_length']
|
||||||
self.mode = params['mode']
|
self.mode = params['mode']
|
||||||
|
if "tps" in params:
|
||||||
|
self.tps = True
|
||||||
if params['mode'] == 'train':
|
if params['mode'] == 'train':
|
||||||
self.batch_size = params['train_batch_size_per_card']
|
self.batch_size = params['train_batch_size_per_card']
|
||||||
self.drop_last = params['drop_last']
|
self.drop_last = params['drop_last']
|
||||||
|
@ -109,7 +111,8 @@ class LMDBReader(object):
|
||||||
norm_img = process_image(
|
norm_img = process_image(
|
||||||
img=img,
|
img=img,
|
||||||
image_shape=self.image_shape,
|
image_shape=self.image_shape,
|
||||||
char_ops=self.char_ops)
|
char_ops=self.char_ops,
|
||||||
|
tps=self.tps)
|
||||||
yield norm_img
|
yield norm_img
|
||||||
else:
|
else:
|
||||||
lmdb_sets = self.load_hierarchical_lmdb_dataset()
|
lmdb_sets = self.load_hierarchical_lmdb_dataset()
|
||||||
|
|
|
@ -92,9 +92,14 @@ def process_image(img,
|
||||||
label=None,
|
label=None,
|
||||||
char_ops=None,
|
char_ops=None,
|
||||||
loss_type=None,
|
loss_type=None,
|
||||||
max_text_length=None):
|
max_text_length=None,
|
||||||
|
tps=None):
|
||||||
if char_ops.character_type == "en":
|
if char_ops.character_type == "en":
|
||||||
norm_img = resize_norm_img(img, image_shape)
|
norm_img = resize_norm_img(img, image_shape)
|
||||||
|
else:
|
||||||
|
if tps:
|
||||||
|
image_shape = [3, 32, 320]
|
||||||
|
norm_img = resize_norm_img(img, image_shape)
|
||||||
else:
|
else:
|
||||||
norm_img = resize_norm_img_chinese(img, image_shape)
|
norm_img = resize_norm_img_chinese(img, image_shape)
|
||||||
norm_img = norm_img[np.newaxis, :]
|
norm_img = norm_img[np.newaxis, :]
|
||||||
|
|
|
@ -30,6 +30,7 @@ class RecModel(object):
|
||||||
global_params = params['Global']
|
global_params = params['Global']
|
||||||
char_num = global_params['char_ops'].get_char_num()
|
char_num = global_params['char_ops'].get_char_num()
|
||||||
global_params['char_num'] = char_num
|
global_params['char_num'] = char_num
|
||||||
|
self.char_type = global_params['character_type']
|
||||||
if "TPS" in params:
|
if "TPS" in params:
|
||||||
tps_params = deepcopy(params["TPS"])
|
tps_params = deepcopy(params["TPS"])
|
||||||
tps_params.update(global_params)
|
tps_params.update(global_params)
|
||||||
|
@ -60,8 +61,8 @@ class RecModel(object):
|
||||||
def create_feed(self, mode):
|
def create_feed(self, mode):
|
||||||
image_shape = deepcopy(self.image_shape)
|
image_shape = deepcopy(self.image_shape)
|
||||||
image_shape.insert(0, -1)
|
image_shape.insert(0, -1)
|
||||||
image = fluid.data(name='image', shape=image_shape, dtype='float32')
|
|
||||||
if mode == "train":
|
if mode == "train":
|
||||||
|
image = fluid.data(name='image', shape=image_shape, dtype='float32')
|
||||||
if self.loss_type == "attention":
|
if self.loss_type == "attention":
|
||||||
label_in = fluid.data(
|
label_in = fluid.data(
|
||||||
name='label_in',
|
name='label_in',
|
||||||
|
@ -86,6 +87,17 @@ class RecModel(object):
|
||||||
use_double_buffer=True,
|
use_double_buffer=True,
|
||||||
iterable=False)
|
iterable=False)
|
||||||
else:
|
else:
|
||||||
|
if self.char_type == "ch":
|
||||||
|
image_shape[-1] = -1
|
||||||
|
if self.tps != None:
|
||||||
|
logger.info(
|
||||||
|
"WARNRNG!!!\n"
|
||||||
|
"TPS does not support variable shape in chinese!"
|
||||||
|
"We set default shape=[3,32,320], it may affect the inference effect"
|
||||||
|
)
|
||||||
|
image_shape[-1] = 320
|
||||||
|
image = fluid.data(
|
||||||
|
name='image', shape=image_shape, dtype='float32')
|
||||||
labels = None
|
labels = None
|
||||||
loader = None
|
loader = None
|
||||||
return image, labels, loader
|
return image, labels, loader
|
||||||
|
|
|
@ -112,7 +112,7 @@ class TextRecognizer(object):
|
||||||
else:
|
else:
|
||||||
preds = rec_idx_batch[rno, 1:end_pos[1]]
|
preds = rec_idx_batch[rno, 1:end_pos[1]]
|
||||||
score = np.mean(predict_batch[rno, 1:end_pos[1]])
|
score = np.mean(predict_batch[rno, 1:end_pos[1]])
|
||||||
#todo: why index has 2 offset
|
#attention index has an offset of 2: beg and end
|
||||||
preds = preds - 2
|
preds = preds - 2
|
||||||
preds_text = self.char_ops.decode(preds)
|
preds_text = self.char_ops.decode(preds)
|
||||||
rec_res.append([preds_text, score])
|
rec_res.append([preds_text, score])
|
||||||
|
@ -138,7 +138,7 @@ if __name__ == "__main__":
|
||||||
except:
|
except:
|
||||||
logger.info(
|
logger.info(
|
||||||
"ERROR!! \nInput image shape is not equal with config. TPS does not support variable shape.\n"
|
"ERROR!! \nInput image shape is not equal with config. TPS does not support variable shape.\n"
|
||||||
"Please set --rec_image_shape=input_shape and --rec_char_type='ch' ")
|
"Please set --rec_image_shape=input_shape and --rec_char_type='en' ")
|
||||||
exit()
|
exit()
|
||||||
for ino in range(len(img_list)):
|
for ino in range(len(img_list)):
|
||||||
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
|
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
|
||||||
|
|
Loading…
Reference in New Issue