Update infer_rec to support SRN
This commit is contained in:
parent
6832ca029f
commit
bf4863c950
|
@ -40,10 +40,12 @@ class LMDBReader(object):
|
|||
self.image_shape = params['image_shape']
|
||||
self.loss_type = params['loss_type']
|
||||
self.max_text_length = params['max_text_length']
|
||||
self.num_heads = params['num_heads']
|
||||
self.mode = params['mode']
|
||||
self.drop_last = False
|
||||
self.use_tps = False
|
||||
self.num_heads = None
|
||||
if "num_heads" in params:
|
||||
self.num_heads = params['num_heads']
|
||||
if "tps" in params:
|
||||
self.ues_tps = True
|
||||
self.use_distort = False
|
||||
|
@ -134,20 +136,6 @@ class LMDBReader(object):
|
|||
tps=self.use_tps,
|
||||
infer_mode=True)
|
||||
yield norm_img
|
||||
#elif self.mode == 'eval':
|
||||
# image_file_list = get_image_file_list(self.infer_img)
|
||||
# for single_img in image_file_list:
|
||||
# img = cv2.imread(single_img)
|
||||
# if img.shape[-1]==1 or len(list(img.shape))==2:
|
||||
# img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
# norm_img = process_image(
|
||||
# img=img,
|
||||
# image_shape=self.image_shape,
|
||||
# char_ops=self.char_ops,
|
||||
# tps=self.use_tps,
|
||||
# infer_mode=True
|
||||
# )
|
||||
# yield norm_img
|
||||
else:
|
||||
lmdb_sets = self.load_hierarchical_lmdb_dataset()
|
||||
if process_id == 0:
|
||||
|
@ -169,14 +157,22 @@ class LMDBReader(object):
|
|||
outs = []
|
||||
if self.loss_type == "srn":
|
||||
outs = process_image_srn(
|
||||
img, self.image_shape, self.num_heads,
|
||||
self.max_text_length, label, self.char_ops,
|
||||
self.loss_type)
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
num_heads=self.num_heads,
|
||||
max_text_length=self.max_text_length,
|
||||
label=label,
|
||||
char_ops=self.char_ops,
|
||||
loss_type=self.loss_type)
|
||||
|
||||
else:
|
||||
outs = process_image(
|
||||
img, self.image_shape, label, self.char_ops,
|
||||
self.loss_type, self.max_text_length)
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
label=label,
|
||||
char_ops=self.char_ops,
|
||||
loss_type=self.loss_type,
|
||||
max_text_length=self.max_text_length)
|
||||
if outs is None:
|
||||
continue
|
||||
yield outs
|
||||
|
@ -192,6 +188,7 @@ class LMDBReader(object):
|
|||
if len(batch_outs) == self.batch_size:
|
||||
yield batch_outs
|
||||
batch_outs = []
|
||||
if not self.drop_last:
|
||||
if len(batch_outs) != 0:
|
||||
yield batch_outs
|
||||
|
||||
|
|
|
@ -58,7 +58,10 @@ class RecModel(object):
|
|||
self.loss_type = global_params['loss_type']
|
||||
self.image_shape = global_params['image_shape']
|
||||
self.max_text_length = global_params['max_text_length']
|
||||
if "num_heads" in params:
|
||||
self.num_heads = global_params["num_heads"]
|
||||
else:
|
||||
self.num_heads = None
|
||||
|
||||
def create_feed(self, mode):
|
||||
image_shape = deepcopy(self.image_shape)
|
||||
|
|
|
@ -32,7 +32,7 @@ class ResNet():
|
|||
def __init__(self, params):
|
||||
self.layers = params['layers']
|
||||
self.is_3x3 = True
|
||||
supported_layers = [18, 34, 50, 101, 152]
|
||||
supported_layers = [18, 34, 50, 101, 152, 200]
|
||||
assert self.layers in supported_layers, \
|
||||
"supported layers are {} but input layer is {}".format(supported_layers, self.layers)
|
||||
|
||||
|
|
|
@ -21,15 +21,12 @@ import math
|
|||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
#from .rec_seq_encoder import SequenceEncoder
|
||||
#from ..common_functions import get_para_bias_attr
|
||||
import numpy as np
|
||||
from .self_attention.model import wrap_encoder
|
||||
from .self_attention.model import wrap_encoder_forFeature
|
||||
gradient_clip = 10
|
||||
|
||||
|
||||
|
||||
class SRNPredict(object):
|
||||
def __init__(self, params):
|
||||
super(SRNPredict, self).__init__()
|
||||
|
@ -41,7 +38,6 @@ class SRNPredict(object):
|
|||
self.num_decoder_TUs = params['num_decoder_TUs']
|
||||
self.hidden_dims = params['hidden_dims']
|
||||
|
||||
|
||||
def pvam(self, inputs, others):
|
||||
|
||||
b, c, h, w = inputs.shape
|
||||
|
@ -53,14 +49,14 @@ class SRNPredict(object):
|
|||
encoder_word_pos = others["encoder_word_pos"]
|
||||
gsrm_word_pos = others["gsrm_word_pos"]
|
||||
|
||||
|
||||
enc_inputs = [conv_features, encoder_word_pos, None]
|
||||
word_features = wrap_encoder_forFeature(src_vocab_size=-1,
|
||||
word_features = wrap_encoder_forFeature(
|
||||
src_vocab_size=-1,
|
||||
max_length=t,
|
||||
n_layer=self.num_encoder_TUs,
|
||||
n_head=self.num_heads,
|
||||
d_key= int(self.hidden_dims / self.num_heads),
|
||||
d_value= int(self.hidden_dims / self.num_heads),
|
||||
d_key=int(self.hidden_dims / self.num_heads),
|
||||
d_value=int(self.hidden_dims / self.num_heads),
|
||||
d_model=self.hidden_dims,
|
||||
d_inner_hid=self.hidden_dims,
|
||||
prepostprocess_dropout=0.1,
|
||||
|
@ -69,26 +65,35 @@ class SRNPredict(object):
|
|||
preprocess_cmd="n",
|
||||
postprocess_cmd="da",
|
||||
weight_sharing=True,
|
||||
enc_inputs=enc_inputs,
|
||||
)
|
||||
fluid.clip.set_gradient_clip(fluid.clip.GradientClipByValue(gradient_clip))
|
||||
enc_inputs=enc_inputs, )
|
||||
fluid.clip.set_gradient_clip(
|
||||
fluid.clip.GradientClipByValue(gradient_clip))
|
||||
|
||||
#===== Parallel Visual Attention Module =====
|
||||
b, t, c = word_features.shape
|
||||
|
||||
word_features = fluid.layers.fc(word_features, c, num_flatten_dims=2)
|
||||
word_features_ = fluid.layers.reshape(word_features, [-1, 1, t, c])
|
||||
word_features_ = fluid.layers.expand(word_features_, [1, self.max_length, 1, 1])
|
||||
word_pos_feature = fluid.layers.embedding(gsrm_word_pos, [self.max_length, c])
|
||||
word_pos_ = fluid.layers.reshape(word_pos_feature, [-1, self.max_length, 1, c])
|
||||
word_features_ = fluid.layers.expand(word_features_,
|
||||
[1, self.max_length, 1, 1])
|
||||
word_pos_feature = fluid.layers.embedding(gsrm_word_pos,
|
||||
[self.max_length, c])
|
||||
word_pos_ = fluid.layers.reshape(word_pos_feature,
|
||||
[-1, self.max_length, 1, c])
|
||||
word_pos_ = fluid.layers.expand(word_pos_, [1, 1, t, 1])
|
||||
temp = fluid.layers.elementwise_add(word_features_, word_pos_, act='tanh')
|
||||
temp = fluid.layers.elementwise_add(
|
||||
word_features_, word_pos_, act='tanh')
|
||||
|
||||
attention_weight = fluid.layers.fc(input=temp, size=1, num_flatten_dims=3, bias_attr=False)
|
||||
attention_weight = fluid.layers.reshape(x=attention_weight, shape=[-1, self.max_length, t])
|
||||
attention_weight = fluid.layers.fc(input=temp,
|
||||
size=1,
|
||||
num_flatten_dims=3,
|
||||
bias_attr=False)
|
||||
attention_weight = fluid.layers.reshape(
|
||||
x=attention_weight, shape=[-1, self.max_length, t])
|
||||
attention_weight = fluid.layers.softmax(input=attention_weight, axis=-1)
|
||||
|
||||
pvam_features = fluid.layers.matmul(attention_weight, word_features)#[b, max_length, c]
|
||||
pvam_features = fluid.layers.matmul(attention_weight,
|
||||
word_features) #[b, max_length, c]
|
||||
|
||||
return pvam_features
|
||||
|
||||
|
@ -96,7 +101,8 @@ class SRNPredict(object):
|
|||
|
||||
#===== GSRM Visual-to-semantic embedding block =====
|
||||
b, t, c = pvam_features.shape
|
||||
word_out = fluid.layers.fc(input=fluid.layers.reshape(pvam_features, [-1, c]),
|
||||
word_out = fluid.layers.fc(
|
||||
input=fluid.layers.reshape(pvam_features, [-1, c]),
|
||||
size=self.char_num,
|
||||
act="softmax")
|
||||
#word_out.stop_gradient = True
|
||||
|
@ -120,7 +126,8 @@ class SRNPredict(object):
|
|||
word1 for forward; word2 for backward
|
||||
"""
|
||||
word1 = fluid.layers.cast(word_ids, "float32")
|
||||
word1 = fluid.layers.pad(word1, [0, 0, 1, 0, 0, 0], pad_value=1.0 * pad_idx)
|
||||
word1 = fluid.layers.pad(word1, [0, 0, 1, 0, 0, 0],
|
||||
pad_value=1.0 * pad_idx)
|
||||
word1 = fluid.layers.cast(word1, "int64")
|
||||
word1 = word1[:, :-1, :]
|
||||
word2 = word_ids
|
||||
|
@ -132,7 +139,8 @@ class SRNPredict(object):
|
|||
enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1]
|
||||
enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2]
|
||||
|
||||
gsrm_feature1 = wrap_encoder(src_vocab_size=self.char_num + 1,
|
||||
gsrm_feature1 = wrap_encoder(
|
||||
src_vocab_size=self.char_num + 1,
|
||||
max_length=self.max_length,
|
||||
n_layer=self.num_decoder_TUs,
|
||||
n_head=self.num_heads,
|
||||
|
@ -146,9 +154,9 @@ class SRNPredict(object):
|
|||
preprocess_cmd="n",
|
||||
postprocess_cmd="da",
|
||||
weight_sharing=True,
|
||||
enc_inputs=enc_inputs_1,
|
||||
)
|
||||
gsrm_feature2 = wrap_encoder(src_vocab_size=self.char_num + 1,
|
||||
enc_inputs=enc_inputs_1, )
|
||||
gsrm_feature2 = wrap_encoder(
|
||||
src_vocab_size=self.char_num + 1,
|
||||
max_length=self.max_length,
|
||||
n_layer=self.num_decoder_TUs,
|
||||
n_head=self.num_heads,
|
||||
|
@ -162,9 +170,9 @@ class SRNPredict(object):
|
|||
preprocess_cmd="n",
|
||||
postprocess_cmd="da",
|
||||
weight_sharing=True,
|
||||
enc_inputs=enc_inputs_2,
|
||||
)
|
||||
gsrm_feature2 = fluid.layers.pad(gsrm_feature2, [0, 0, 0, 1, 0, 0], pad_value=0.)
|
||||
enc_inputs=enc_inputs_2, )
|
||||
gsrm_feature2 = fluid.layers.pad(gsrm_feature2, [0, 0, 0, 1, 0, 0],
|
||||
pad_value=0.)
|
||||
gsrm_feature2 = gsrm_feature2[:, 1:, ]
|
||||
gsrm_features = gsrm_feature1 + gsrm_feature2
|
||||
|
||||
|
@ -172,10 +180,12 @@ class SRNPredict(object):
|
|||
|
||||
gsrm_out = fluid.layers.matmul(
|
||||
x=gsrm_features,
|
||||
y=fluid.default_main_program().global_block().var("src_word_emb_table"),
|
||||
y=fluid.default_main_program().global_block().var(
|
||||
"src_word_emb_table"),
|
||||
transpose_y=True)
|
||||
b,t,c = gsrm_out.shape
|
||||
gsrm_out = fluid.layers.softmax(input=fluid.layers.reshape(gsrm_out, [-1, c]))
|
||||
b, t, c = gsrm_out.shape
|
||||
gsrm_out = fluid.layers.softmax(input=fluid.layers.reshape(gsrm_out,
|
||||
[-1, c]))
|
||||
|
||||
return gsrm_features, word_out, gsrm_out
|
||||
|
||||
|
@ -184,19 +194,25 @@ class SRNPredict(object):
|
|||
#===== Visual-Semantic Fusion Decoder Module =====
|
||||
b, t, c1 = pvam_features.shape
|
||||
b, t, c2 = gsrm_features.shape
|
||||
combine_features_ = fluid.layers.concat([pvam_features, gsrm_features], axis=2)
|
||||
img_comb_features_ = fluid.layers.reshape(x=combine_features_, shape=[-1, c1 + c2])
|
||||
img_comb_features_map = fluid.layers.fc(input=img_comb_features_, size=c1, act="sigmoid")
|
||||
img_comb_features_map = fluid.layers.reshape(x=img_comb_features_map, shape=[-1, t, c1])
|
||||
combine_features = img_comb_features_map * pvam_features + (1.0 - img_comb_features_map) * gsrm_features
|
||||
img_comb_features = fluid.layers.reshape(x=combine_features, shape=[-1, c1])
|
||||
combine_features_ = fluid.layers.concat(
|
||||
[pvam_features, gsrm_features], axis=2)
|
||||
img_comb_features_ = fluid.layers.reshape(
|
||||
x=combine_features_, shape=[-1, c1 + c2])
|
||||
img_comb_features_map = fluid.layers.fc(input=img_comb_features_,
|
||||
size=c1,
|
||||
act="sigmoid")
|
||||
img_comb_features_map = fluid.layers.reshape(
|
||||
x=img_comb_features_map, shape=[-1, t, c1])
|
||||
combine_features = img_comb_features_map * pvam_features + (
|
||||
1.0 - img_comb_features_map) * gsrm_features
|
||||
img_comb_features = fluid.layers.reshape(
|
||||
x=combine_features, shape=[-1, c1])
|
||||
|
||||
fc_out = fluid.layers.fc(input=img_comb_features,
|
||||
size=self.char_num,
|
||||
act="softmax")
|
||||
return fc_out
|
||||
|
||||
|
||||
def __call__(self, inputs, others, mode=None):
|
||||
|
||||
pvam_features = self.pvam(inputs, others)
|
||||
|
@ -204,15 +220,11 @@ class SRNPredict(object):
|
|||
final_out = self.vsfd(pvam_features, gsrm_features)
|
||||
|
||||
_, decoded_out = fluid.layers.topk(input=final_out, k=1)
|
||||
predicts = {'predict': final_out, 'decoded_out': decoded_out,
|
||||
'word_out': word_out, 'gsrm_out': gsrm_out}
|
||||
predicts = {
|
||||
'predict': final_out,
|
||||
'decoded_out': decoded_out,
|
||||
'word_out': word_out,
|
||||
'gsrm_out': gsrm_out
|
||||
}
|
||||
|
||||
return predicts
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -35,24 +35,21 @@ class SRNLoss(object):
|
|||
lbl_weight = others['lbl_weight']
|
||||
|
||||
casted_label = fluid.layers.cast(x=label, dtype='int64')
|
||||
cost_word = fluid.layers.cross_entropy(input=word_predict, label=casted_label)
|
||||
cost_gsrm = fluid.layers.cross_entropy(input=gsrm_predict, label=casted_label)
|
||||
cost_vsfd = fluid.layers.cross_entropy(input=predict, label=casted_label)
|
||||
cost_word = fluid.layers.cross_entropy(
|
||||
input=word_predict, label=casted_label)
|
||||
cost_gsrm = fluid.layers.cross_entropy(
|
||||
input=gsrm_predict, label=casted_label)
|
||||
cost_vsfd = fluid.layers.cross_entropy(
|
||||
input=predict, label=casted_label)
|
||||
|
||||
#cost_word = cost_word * lbl_weight
|
||||
#cost_gsrm = cost_gsrm * lbl_weight
|
||||
#cost_vsfd = cost_vsfd * lbl_weight
|
||||
cost_word = fluid.layers.reshape(
|
||||
x=fluid.layers.reduce_sum(cost_word), shape=[1])
|
||||
cost_gsrm = fluid.layers.reshape(
|
||||
x=fluid.layers.reduce_sum(cost_gsrm), shape=[1])
|
||||
cost_vsfd = fluid.layers.reshape(
|
||||
x=fluid.layers.reduce_sum(cost_vsfd), shape=[1])
|
||||
|
||||
cost_word = fluid.layers.reshape(x=fluid.layers.reduce_sum(cost_word), shape=[1])
|
||||
cost_gsrm = fluid.layers.reshape(x=fluid.layers.reduce_sum(cost_gsrm), shape=[1])
|
||||
cost_vsfd = fluid.layers.reshape(x=fluid.layers.reduce_sum(cost_vsfd), shape=[1])
|
||||
sum_cost = fluid.layers.sum(
|
||||
[cost_word, cost_vsfd * 2.0, cost_gsrm * 0.15])
|
||||
|
||||
sum_cost = fluid.layers.sum([cost_word, cost_vsfd * 2.0, cost_gsrm * 0.15])
|
||||
|
||||
#sum_cost = fluid.layers.sum([cost_word * 3.0, cost_vsfd, cost_gsrm * 0.15])
|
||||
#sum_cost = cost_word
|
||||
|
||||
#fluid.layers.Print(cost_word,message="word_cost")
|
||||
#fluid.layers.Print(cost_vsfd,message="img_cost")
|
||||
return [sum_cost,cost_vsfd,cost_word]
|
||||
#return [sum_cost, cost_vsfd, cost_word]
|
||||
return [sum_cost, cost_vsfd, cost_word]
|
||||
|
|
|
@ -149,6 +149,7 @@ def cal_predicts_accuracy(char_ops,
|
|||
acc = acc_num * 1.0 / img_num
|
||||
return acc, acc_num, img_num
|
||||
|
||||
|
||||
def cal_predicts_accuracy_srn(char_ops,
|
||||
preds,
|
||||
labels,
|
||||
|
@ -159,7 +160,6 @@ def cal_predicts_accuracy_srn(char_ops,
|
|||
|
||||
total_len = preds.shape[0]
|
||||
img_num = int(total_len / max_text_len)
|
||||
#print (img_num)
|
||||
for i in range(img_num):
|
||||
cur_label = []
|
||||
cur_pred = []
|
||||
|
@ -169,18 +169,9 @@ def cal_predicts_accuracy_srn(char_ops,
|
|||
else:
|
||||
break
|
||||
|
||||
if is_debug:
|
||||
for j in range(max_text_len):
|
||||
if preds[j + i * max_text_len] != 37: #0
|
||||
cur_pred.append(preds[j + i * max_text_len][0])
|
||||
else:
|
||||
break
|
||||
print ("cur_label: ", cur_label)
|
||||
print ("cur_pred: ", cur_pred)
|
||||
|
||||
|
||||
for j in range(max_text_len + 1):
|
||||
if j < len(cur_label) and preds[j + i * max_text_len][0] != cur_label[j]:
|
||||
if j < len(cur_label) and preds[j + i * max_text_len][
|
||||
0] != cur_label[j]:
|
||||
break
|
||||
elif j == len(cur_label) and j == max_text_len:
|
||||
acc_num += 1
|
||||
|
|
|
@ -123,8 +123,8 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
|
|||
|
||||
def test_rec_benchmark(exe, config, eval_info_dict):
|
||||
" Evaluate lmdb dataset "
|
||||
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', \
|
||||
'IC13_857', 'IC15_1811', 'IC15_2077','SVTP', 'CUTE80']
|
||||
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860','IC03_867', \
|
||||
'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077','SVTP', 'CUTE80']
|
||||
eval_data_dir = config['TestReader']['lmdb_sets_dir']
|
||||
total_evaluation_data_number = 0
|
||||
total_correct_number = 0
|
||||
|
|
|
@ -64,7 +64,6 @@ def main():
|
|||
exe = fluid.Executor(place)
|
||||
|
||||
rec_model = create_module(config['Architecture']['function'])(params=config)
|
||||
|
||||
startup_prog = fluid.Program()
|
||||
eval_prog = fluid.Program()
|
||||
with fluid.program_guard(eval_prog, startup_prog):
|
||||
|
@ -86,10 +85,36 @@ def main():
|
|||
for i in range(max_img_num):
|
||||
logger.info("infer_img:%s" % infer_list[i])
|
||||
img = next(blobs)
|
||||
if loss_type != "srn":
|
||||
predict = exe.run(program=eval_prog,
|
||||
feed={"image": img},
|
||||
fetch_list=fetch_varname_list,
|
||||
return_numpy=False)
|
||||
else:
|
||||
encoder_word_pos_list = []
|
||||
gsrm_word_pos_list = []
|
||||
gsrm_slf_attn_bias1_list = []
|
||||
gsrm_slf_attn_bias2_list = []
|
||||
encoder_word_pos_list.append(img[1])
|
||||
gsrm_word_pos_list.append(img[2])
|
||||
gsrm_slf_attn_bias1_list.append(img[3])
|
||||
gsrm_slf_attn_bias2_list.append(img[4])
|
||||
|
||||
encoder_word_pos_list = np.concatenate(
|
||||
encoder_word_pos_list, axis=0).astype(np.int64)
|
||||
gsrm_word_pos_list = np.concatenate(
|
||||
gsrm_word_pos_list, axis=0).astype(np.int64)
|
||||
gsrm_slf_attn_bias1_list = np.concatenate(
|
||||
gsrm_slf_attn_bias1_list, axis=0).astype(np.float32)
|
||||
gsrm_slf_attn_bias2_list = np.concatenate(
|
||||
gsrm_slf_attn_bias2_list, axis=0).astype(np.float32)
|
||||
|
||||
predict = exe.run(program=eval_prog, \
|
||||
feed={'image': img[0], 'encoder_word_pos': encoder_word_pos_list,
|
||||
'gsrm_word_pos': gsrm_word_pos_list, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1_list,
|
||||
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2_list}, \
|
||||
fetch_list=fetch_varname_list, \
|
||||
return_numpy=False)
|
||||
if loss_type == "ctc":
|
||||
preds = np.array(predict[0])
|
||||
preds = preds.reshape(-1)
|
||||
|
@ -114,7 +139,18 @@ def main():
|
|||
score = np.mean(probs[0, 1:end_pos[1]])
|
||||
preds = preds.reshape(-1)
|
||||
preds_text = char_ops.decode(preds)
|
||||
|
||||
elif loss_type == "srn":
|
||||
cur_pred = []
|
||||
preds = np.array(predict[0])
|
||||
preds = preds.reshape(-1)
|
||||
probs = np.array(predict[1])
|
||||
ind = np.argmax(probs, axis=1)
|
||||
valid_ind = np.where(preds != 37)[0]
|
||||
if len(valid_ind) == 0:
|
||||
continue
|
||||
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
||||
preds = preds[:valid_ind[-1] + 1]
|
||||
preds_text = char_ops.decode(preds)
|
||||
logger.info("\t index: {}".format(preds))
|
||||
logger.info("\t word : {}".format(preds_text))
|
||||
logger.info("\t score: {}".format(score))
|
||||
|
|
Loading…
Reference in New Issue