polish code

This commit is contained in:
tink2123 2020-08-16 17:09:17 +08:00
parent fe8ce9afdf
commit ab0acf78b4
2 changed files with 11 additions and 7 deletions

View File

@ -98,13 +98,15 @@ class RecModel(object):
shape=[
-1, self.num_heads, self.max_text_length,
self.max_text_length
])
],
dtype="float32")
gsrm_slf_attn_bias2 = fluid.data(
name="gsrm_slf_attn_bias2",
shape=[
-1, self.num_heads, self.max_text_length,
self.max_text_length
])
],
dtype="float32")
lbl_weight = fluid.layers.data(
name="lbl_weight", shape=[-1, 1], dtype='int64')
label = fluid.data(
@ -161,13 +163,15 @@ class RecModel(object):
shape=[
-1, self.num_heads, self.max_text_length,
self.max_text_length
])
],
dtype="float32")
gsrm_slf_attn_bias2 = fluid.data(
name="gsrm_slf_attn_bias2",
shape=[
-1, self.num_heads, self.max_text_length,
self.max_text_length
])
],
dtype="float32")
feed_list = [
image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2
@ -214,7 +218,7 @@ class RecModel(object):
if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict)
if self.loss_type == "srn":
logger.infor(
raise Exception(
"Warning! SRN does not support export model currently")
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
else:

View File

@ -26,12 +26,12 @@ class CharacterOps(object):
self.character_type = config['character_type']
self.loss_type = config['loss_type']
self.max_text_len = config['max_text_length']
if self.loss_type == "srn" and self.character_type == "ch":
raise Exception("SRN can only support in character_type == en")
if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
elif self.character_type == "ch":
if self.loss_type == "srn":
raise Exception("SRN can only support in character_type == en")
character_dict_path = config['character_dict_path']
add_space = False
if 'use_space_char' in config: