polish code
This commit is contained in:
parent
fe8ce9afdf
commit
ab0acf78b4
|
@ -98,13 +98,15 @@ class RecModel(object):
|
|||
shape=[
|
||||
-1, self.num_heads, self.max_text_length,
|
||||
self.max_text_length
|
||||
])
|
||||
],
|
||||
dtype="float32")
|
||||
gsrm_slf_attn_bias2 = fluid.data(
|
||||
name="gsrm_slf_attn_bias2",
|
||||
shape=[
|
||||
-1, self.num_heads, self.max_text_length,
|
||||
self.max_text_length
|
||||
])
|
||||
],
|
||||
dtype="float32")
|
||||
lbl_weight = fluid.layers.data(
|
||||
name="lbl_weight", shape=[-1, 1], dtype='int64')
|
||||
label = fluid.data(
|
||||
|
@ -161,13 +163,15 @@ class RecModel(object):
|
|||
shape=[
|
||||
-1, self.num_heads, self.max_text_length,
|
||||
self.max_text_length
|
||||
])
|
||||
],
|
||||
dtype="float32")
|
||||
gsrm_slf_attn_bias2 = fluid.data(
|
||||
name="gsrm_slf_attn_bias2",
|
||||
shape=[
|
||||
-1, self.num_heads, self.max_text_length,
|
||||
self.max_text_length
|
||||
])
|
||||
],
|
||||
dtype="float32")
|
||||
feed_list = [
|
||||
image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2
|
||||
|
@ -214,7 +218,7 @@ class RecModel(object):
|
|||
if self.loss_type == "ctc":
|
||||
predict = fluid.layers.softmax(predict)
|
||||
if self.loss_type == "srn":
|
||||
logger.infor(
|
||||
raise Exception(
|
||||
"Warning! SRN does not support export model currently")
|
||||
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
|
||||
else:
|
||||
|
|
|
@ -26,12 +26,12 @@ class CharacterOps(object):
|
|||
self.character_type = config['character_type']
|
||||
self.loss_type = config['loss_type']
|
||||
self.max_text_len = config['max_text_length']
|
||||
if self.loss_type == "srn" and self.character_type == "ch":
|
||||
raise Exception("SRN can only support in character_type == en")
|
||||
if self.character_type == "en":
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
elif self.character_type == "ch":
|
||||
if self.loss_type == "srn":
|
||||
raise Exception("SRN can only support in character_type == en")
|
||||
character_dict_path = config['character_dict_path']
|
||||
add_space = False
|
||||
if 'use_space_char' in config:
|
||||
|
|
Loading…
Reference in New Issue