This commit is contained in:
tink2123 2020-06-03 17:09:14 +08:00
parent be3a164424
commit 80c188785c
3 changed files with 19 additions and 11 deletions

View File

@@ -41,6 +41,7 @@ class LMDBReader(object):
self.loss_type = params['loss_type'] self.loss_type = params['loss_type']
self.max_text_length = params['max_text_length'] self.max_text_length = params['max_text_length']
self.mode = params['mode'] self.mode = params['mode']
self.drop_last = False
if "tps" in params: if "tps" in params:
self.tps = True self.tps = True
if params['mode'] == 'train': if params['mode'] == 'train':
@@ -112,7 +113,8 @@ class LMDBReader(object):
img=img, img=img,
image_shape=self.image_shape, image_shape=self.image_shape,
char_ops=self.char_ops, char_ops=self.char_ops,
tps=self.tps) tps=self.tps,
infer_mode=True)
yield norm_img yield norm_img
else: else:
lmdb_sets = self.load_hierarchical_lmdb_dataset() lmdb_sets = self.load_hierarchical_lmdb_dataset()
@@ -132,9 +134,13 @@ class LMDBReader(object):
if sample_info is None: if sample_info is None:
continue continue
img, label = sample_info img, label = sample_info
outs = process_image(img, self.image_shape, label, outs = process_image(
self.char_ops, self.loss_type, img=img,
self.max_text_length) image_shape=self.image_shape,
label=label,
char_ops=self.char_ops,
loss_type=self.loss_type,
max_text_length=self.max_text_length)
if outs is None: if outs is None:
continue continue
yield outs yield outs
@@ -154,7 +160,7 @@ class LMDBReader(object):
if len(batch_outs) != 0: if len(batch_outs) != 0:
yield batch_outs yield batch_outs
if self.mode != 'train' and self.infer_img is None: if self.infer_img is None:
return batch_iter_reader return batch_iter_reader
return sample_iter_reader return sample_iter_reader
@@ -174,6 +180,7 @@ class SimpleReader(object):
self.max_text_length = params['max_text_length'] self.max_text_length = params['max_text_length']
self.mode = params['mode'] self.mode = params['mode']
self.infer_img = params['infer_img'] self.infer_img = params['infer_img']
self.drop_last = False
if params['mode'] == 'train': if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card'] self.batch_size = params['train_batch_size_per_card']
self.drop_last = params['drop_last'] self.drop_last = params['drop_last']

View File

@@ -93,11 +93,12 @@ def process_image(img,
char_ops=None, char_ops=None,
loss_type=None, loss_type=None,
max_text_length=None, max_text_length=None,
tps=None): tps=None,
if char_ops.character_type == "en": infer_mode=False):
if not infer_mode or char_ops.character_type == "en":
norm_img = resize_norm_img(img, image_shape) norm_img = resize_norm_img(img, image_shape)
else: else:
if tps: if tps != None and char_ops.character_type == "ch":
image_shape = [3, 32, 320] image_shape = [3, 32, 320]
norm_img = resize_norm_img(img, image_shape) norm_img = resize_norm_img(img, image_shape)
else: else:

View File

@@ -31,6 +31,7 @@ class RecModel(object):
char_num = global_params['char_ops'].get_char_num() char_num = global_params['char_ops'].get_char_num()
global_params['char_num'] = char_num global_params['char_num'] = char_num
self.char_type = global_params['character_type'] self.char_type = global_params['character_type']
self.infer_img = global_params['infer_img']
if "TPS" in params: if "TPS" in params:
tps_params = deepcopy(params["TPS"]) tps_params = deepcopy(params["TPS"])
tps_params.update(global_params) tps_params.update(global_params)
@@ -87,7 +88,7 @@ class RecModel(object):
use_double_buffer=True, use_double_buffer=True,
iterable=False) iterable=False)
else: else:
if self.char_type == "ch": if self.char_type == "ch" and self.infer_img:
image_shape[-1] = -1 image_shape[-1] = -1
if self.tps != None: if self.tps != None:
logger.info( logger.info(
@@ -96,8 +97,7 @@ class RecModel(object):
"We set default shape=[3,32,320], it may affect the inference effect" "We set default shape=[3,32,320], it may affect the inference effect"
) )
image_shape[-1] = 320 image_shape[-1] = 320
image = fluid.data( image = fluid.data(name='image', shape=image_shape, dtype='float32')
name='image', shape=image_shape, dtype='float32')
labels = None labels = None
loader = None loader = None
return image, labels, loader return image, labels, loader