fix eval
This commit is contained in:
parent
be3a164424
commit
80c188785c
|
@ -41,6 +41,7 @@ class LMDBReader(object):
|
||||||
self.loss_type = params['loss_type']
|
self.loss_type = params['loss_type']
|
||||||
self.max_text_length = params['max_text_length']
|
self.max_text_length = params['max_text_length']
|
||||||
self.mode = params['mode']
|
self.mode = params['mode']
|
||||||
|
self.drop_last = False
|
||||||
if "tps" in params:
|
if "tps" in params:
|
||||||
self.tps = True
|
self.tps = True
|
||||||
if params['mode'] == 'train':
|
if params['mode'] == 'train':
|
||||||
|
@ -112,7 +113,8 @@ class LMDBReader(object):
|
||||||
img=img,
|
img=img,
|
||||||
image_shape=self.image_shape,
|
image_shape=self.image_shape,
|
||||||
char_ops=self.char_ops,
|
char_ops=self.char_ops,
|
||||||
tps=self.tps)
|
tps=self.tps,
|
||||||
|
infer_mode=True)
|
||||||
yield norm_img
|
yield norm_img
|
||||||
else:
|
else:
|
||||||
lmdb_sets = self.load_hierarchical_lmdb_dataset()
|
lmdb_sets = self.load_hierarchical_lmdb_dataset()
|
||||||
|
@ -132,9 +134,13 @@ class LMDBReader(object):
|
||||||
if sample_info is None:
|
if sample_info is None:
|
||||||
continue
|
continue
|
||||||
img, label = sample_info
|
img, label = sample_info
|
||||||
outs = process_image(img, self.image_shape, label,
|
outs = process_image(
|
||||||
self.char_ops, self.loss_type,
|
img=img,
|
||||||
self.max_text_length)
|
image_shape=self.image_shape,
|
||||||
|
label=label,
|
||||||
|
char_ops=self.char_ops,
|
||||||
|
loss_type=self.loss_type,
|
||||||
|
max_text_length=self.max_text_length)
|
||||||
if outs is None:
|
if outs is None:
|
||||||
continue
|
continue
|
||||||
yield outs
|
yield outs
|
||||||
|
@ -154,7 +160,7 @@ class LMDBReader(object):
|
||||||
if len(batch_outs) != 0:
|
if len(batch_outs) != 0:
|
||||||
yield batch_outs
|
yield batch_outs
|
||||||
|
|
||||||
if self.mode != 'train' and self.infer_img is None:
|
if self.infer_img is None:
|
||||||
return batch_iter_reader
|
return batch_iter_reader
|
||||||
return sample_iter_reader
|
return sample_iter_reader
|
||||||
|
|
||||||
|
@ -174,6 +180,7 @@ class SimpleReader(object):
|
||||||
self.max_text_length = params['max_text_length']
|
self.max_text_length = params['max_text_length']
|
||||||
self.mode = params['mode']
|
self.mode = params['mode']
|
||||||
self.infer_img = params['infer_img']
|
self.infer_img = params['infer_img']
|
||||||
|
self.drop_last = False
|
||||||
if params['mode'] == 'train':
|
if params['mode'] == 'train':
|
||||||
self.batch_size = params['train_batch_size_per_card']
|
self.batch_size = params['train_batch_size_per_card']
|
||||||
self.drop_last = params['drop_last']
|
self.drop_last = params['drop_last']
|
||||||
|
|
|
@ -93,11 +93,12 @@ def process_image(img,
|
||||||
char_ops=None,
|
char_ops=None,
|
||||||
loss_type=None,
|
loss_type=None,
|
||||||
max_text_length=None,
|
max_text_length=None,
|
||||||
tps=None):
|
tps=None,
|
||||||
if char_ops.character_type == "en":
|
infer_mode=False):
|
||||||
|
if not infer_mode or char_ops.character_type == "en":
|
||||||
norm_img = resize_norm_img(img, image_shape)
|
norm_img = resize_norm_img(img, image_shape)
|
||||||
else:
|
else:
|
||||||
if tps:
|
if tps != None and char_ops.character_type == "ch":
|
||||||
image_shape = [3, 32, 320]
|
image_shape = [3, 32, 320]
|
||||||
norm_img = resize_norm_img(img, image_shape)
|
norm_img = resize_norm_img(img, image_shape)
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -31,6 +31,7 @@ class RecModel(object):
|
||||||
char_num = global_params['char_ops'].get_char_num()
|
char_num = global_params['char_ops'].get_char_num()
|
||||||
global_params['char_num'] = char_num
|
global_params['char_num'] = char_num
|
||||||
self.char_type = global_params['character_type']
|
self.char_type = global_params['character_type']
|
||||||
|
self.infer_img = global_params['infer_img']
|
||||||
if "TPS" in params:
|
if "TPS" in params:
|
||||||
tps_params = deepcopy(params["TPS"])
|
tps_params = deepcopy(params["TPS"])
|
||||||
tps_params.update(global_params)
|
tps_params.update(global_params)
|
||||||
|
@ -87,7 +88,7 @@ class RecModel(object):
|
||||||
use_double_buffer=True,
|
use_double_buffer=True,
|
||||||
iterable=False)
|
iterable=False)
|
||||||
else:
|
else:
|
||||||
if self.char_type == "ch":
|
if self.char_type == "ch" and self.infer_img:
|
||||||
image_shape[-1] = -1
|
image_shape[-1] = -1
|
||||||
if self.tps != None:
|
if self.tps != None:
|
||||||
logger.info(
|
logger.info(
|
||||||
|
@ -96,8 +97,7 @@ class RecModel(object):
|
||||||
"We set default shape=[3,32,320], it may affect the inference effect"
|
"We set default shape=[3,32,320], it may affect the inference effect"
|
||||||
)
|
)
|
||||||
image_shape[-1] = 320
|
image_shape[-1] = 320
|
||||||
image = fluid.data(
|
image = fluid.data(name='image', shape=image_shape, dtype='float32')
|
||||||
name='image', shape=image_shape, dtype='float32')
|
|
||||||
labels = None
|
labels = None
|
||||||
loader = None
|
loader = None
|
||||||
return image, labels, loader
|
return image, labels, loader
|
||||||
|
|
Loading…
Reference in New Issue