fix eval
This commit is contained in:
parent
be3a164424
commit
80c188785c
|
@ -41,6 +41,7 @@ class LMDBReader(object):
|
|||
self.loss_type = params['loss_type']
|
||||
self.max_text_length = params['max_text_length']
|
||||
self.mode = params['mode']
|
||||
self.drop_last = False
|
||||
if "tps" in params:
|
||||
self.tps = True
|
||||
if params['mode'] == 'train':
|
||||
|
@ -112,7 +113,8 @@ class LMDBReader(object):
|
|||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
char_ops=self.char_ops,
|
||||
tps=self.tps)
|
||||
tps=self.tps,
|
||||
infer_mode=True)
|
||||
yield norm_img
|
||||
else:
|
||||
lmdb_sets = self.load_hierarchical_lmdb_dataset()
|
||||
|
@ -132,9 +134,13 @@ class LMDBReader(object):
|
|||
if sample_info is None:
|
||||
continue
|
||||
img, label = sample_info
|
||||
outs = process_image(img, self.image_shape, label,
|
||||
self.char_ops, self.loss_type,
|
||||
self.max_text_length)
|
||||
outs = process_image(
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
label=label,
|
||||
char_ops=self.char_ops,
|
||||
loss_type=self.loss_type,
|
||||
max_text_length=self.max_text_length)
|
||||
if outs is None:
|
||||
continue
|
||||
yield outs
|
||||
|
@ -154,7 +160,7 @@ class LMDBReader(object):
|
|||
if len(batch_outs) != 0:
|
||||
yield batch_outs
|
||||
|
||||
if self.mode != 'train' and self.infer_img is None:
|
||||
if self.infer_img is None:
|
||||
return batch_iter_reader
|
||||
return sample_iter_reader
|
||||
|
||||
|
@ -174,6 +180,7 @@ class SimpleReader(object):
|
|||
self.max_text_length = params['max_text_length']
|
||||
self.mode = params['mode']
|
||||
self.infer_img = params['infer_img']
|
||||
self.drop_last = False
|
||||
if params['mode'] == 'train':
|
||||
self.batch_size = params['train_batch_size_per_card']
|
||||
self.drop_last = params['drop_last']
|
||||
|
|
|
@ -93,11 +93,12 @@ def process_image(img,
|
|||
char_ops=None,
|
||||
loss_type=None,
|
||||
max_text_length=None,
|
||||
tps=None):
|
||||
if char_ops.character_type == "en":
|
||||
tps=None,
|
||||
infer_mode=False):
|
||||
if not infer_mode or char_ops.character_type == "en":
|
||||
norm_img = resize_norm_img(img, image_shape)
|
||||
else:
|
||||
if tps:
|
||||
if tps != None and char_ops.character_type == "ch":
|
||||
image_shape = [3, 32, 320]
|
||||
norm_img = resize_norm_img(img, image_shape)
|
||||
else:
|
||||
|
|
|
@ -31,6 +31,7 @@ class RecModel(object):
|
|||
char_num = global_params['char_ops'].get_char_num()
|
||||
global_params['char_num'] = char_num
|
||||
self.char_type = global_params['character_type']
|
||||
self.infer_img = global_params['infer_img']
|
||||
if "TPS" in params:
|
||||
tps_params = deepcopy(params["TPS"])
|
||||
tps_params.update(global_params)
|
||||
|
@ -87,7 +88,7 @@ class RecModel(object):
|
|||
use_double_buffer=True,
|
||||
iterable=False)
|
||||
else:
|
||||
if self.char_type == "ch":
|
||||
if self.char_type == "ch" and self.infer_img:
|
||||
image_shape[-1] = -1
|
||||
if self.tps != None:
|
||||
logger.info(
|
||||
|
@ -96,8 +97,7 @@ class RecModel(object):
|
|||
"We set default shape=[3,32,320], it may affect the inference effect"
|
||||
)
|
||||
image_shape[-1] = 320
|
||||
image = fluid.data(
|
||||
name='image', shape=image_shape, dtype='float32')
|
||||
image = fluid.data(name='image', shape=image_shape, dtype='float32')
|
||||
labels = None
|
||||
loader = None
|
||||
return image, labels, loader
|
||||
|
|
Loading…
Reference in New Issue