This commit is contained in:
tink2123 2020-06-03 17:09:14 +08:00
parent be3a164424
commit 80c188785c
3 changed files with 19 additions and 11 deletions

View File

@@ -41,6 +41,7 @@ class LMDBReader(object):
self.loss_type = params['loss_type']
self.max_text_length = params['max_text_length']
self.mode = params['mode']
self.drop_last = False
if "tps" in params:
self.tps = True
if params['mode'] == 'train':
@@ -112,7 +113,8 @@ class LMDBReader(object):
img=img,
image_shape=self.image_shape,
char_ops=self.char_ops,
tps=self.tps)
tps=self.tps,
infer_mode=True)
yield norm_img
else:
lmdb_sets = self.load_hierarchical_lmdb_dataset()
@@ -132,9 +134,13 @@ class LMDBReader(object):
if sample_info is None:
continue
img, label = sample_info
outs = process_image(img, self.image_shape, label,
self.char_ops, self.loss_type,
self.max_text_length)
outs = process_image(
img=img,
image_shape=self.image_shape,
label=label,
char_ops=self.char_ops,
loss_type=self.loss_type,
max_text_length=self.max_text_length)
if outs is None:
continue
yield outs
@@ -154,7 +160,7 @@ class LMDBReader(object):
if len(batch_outs) != 0:
yield batch_outs
if self.mode != 'train' and self.infer_img is None:
if self.infer_img is None:
return batch_iter_reader
return sample_iter_reader
@@ -174,6 +180,7 @@ class SimpleReader(object):
self.max_text_length = params['max_text_length']
self.mode = params['mode']
self.infer_img = params['infer_img']
self.drop_last = False
if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card']
self.drop_last = params['drop_last']

View File

@@ -93,11 +93,12 @@ def process_image(img,
char_ops=None,
loss_type=None,
max_text_length=None,
tps=None):
if char_ops.character_type == "en":
tps=None,
infer_mode=False):
if not infer_mode or char_ops.character_type == "en":
norm_img = resize_norm_img(img, image_shape)
else:
if tps:
if tps != None and char_ops.character_type == "ch":
image_shape = [3, 32, 320]
norm_img = resize_norm_img(img, image_shape)
else:

View File

@@ -31,6 +31,7 @@ class RecModel(object):
char_num = global_params['char_ops'].get_char_num()
global_params['char_num'] = char_num
self.char_type = global_params['character_type']
self.infer_img = global_params['infer_img']
if "TPS" in params:
tps_params = deepcopy(params["TPS"])
tps_params.update(global_params)
@@ -87,7 +88,7 @@ class RecModel(object):
use_double_buffer=True,
iterable=False)
else:
if self.char_type == "ch":
if self.char_type == "ch" and self.infer_img:
image_shape[-1] = -1
if self.tps != None:
logger.info(
@@ -96,8 +97,7 @@ class RecModel(object):
"We set default shape=[3,32,320], it may affect the inference effect"
)
image_shape[-1] = 320
image = fluid.data(
name='image', shape=image_shape, dtype='float32')
image = fluid.data(name='image', shape=image_shape, dtype='float32')
labels = None
loader = None
return image, labels, loader