Adapt SRN to Chinese character sets and ResNet-34/18 backbones (r34/18)
This commit is contained in:
parent
7b201a3855
commit
1b19050391
|
@ -214,6 +214,8 @@ class SimpleReader(object):
|
||||||
self.mode = params['mode']
|
self.mode = params['mode']
|
||||||
self.infer_img = params['infer_img']
|
self.infer_img = params['infer_img']
|
||||||
self.use_tps = False
|
self.use_tps = False
|
||||||
|
if "num_heads" in params:
|
||||||
|
self.num_heads = params['num_heads']
|
||||||
if "tps" in params:
|
if "tps" in params:
|
||||||
self.use_tps = True
|
self.use_tps = True
|
||||||
self.use_distort = False
|
self.use_distort = False
|
||||||
|
@ -251,12 +253,19 @@ class SimpleReader(object):
|
||||||
img = cv2.imread(single_img)
|
img = cv2.imread(single_img)
|
||||||
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
|
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||||
norm_img = process_image(
|
if self.loss_type == 'srn':
|
||||||
img=img,
|
norm_img = process_image_srn(
|
||||||
image_shape=self.image_shape,
|
img=img,
|
||||||
char_ops=self.char_ops,
|
image_shape=self.image_shape,
|
||||||
tps=self.use_tps,
|
num_heads=self.num_heads,
|
||||||
infer_mode=True)
|
max_text_length=self.max_text_length)
|
||||||
|
else:
|
||||||
|
norm_img = process_image(
|
||||||
|
img=img,
|
||||||
|
image_shape=self.image_shape,
|
||||||
|
char_ops=self.char_ops,
|
||||||
|
tps=self.use_tps,
|
||||||
|
infer_mode=True)
|
||||||
yield norm_img
|
yield norm_img
|
||||||
else:
|
else:
|
||||||
with open(self.label_file_path, "rb") as fin:
|
with open(self.label_file_path, "rb") as fin:
|
||||||
|
@ -286,14 +295,25 @@ class SimpleReader(object):
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||||
|
|
||||||
label = substr[1]
|
label = substr[1]
|
||||||
outs = process_image(
|
if self.loss_type == "srn":
|
||||||
img=img,
|
outs = process_image_srn(
|
||||||
image_shape=self.image_shape,
|
img=img,
|
||||||
label=label,
|
image_shape=self.image_shape,
|
||||||
char_ops=self.char_ops,
|
num_heads=self.num_heads,
|
||||||
loss_type=self.loss_type,
|
max_text_length=self.max_text_length,
|
||||||
max_text_length=self.max_text_length,
|
label=label,
|
||||||
distort=self.use_distort)
|
char_ops=self.char_ops,
|
||||||
|
loss_type=self.loss_type)
|
||||||
|
|
||||||
|
else:
|
||||||
|
outs = process_image(
|
||||||
|
img=img,
|
||||||
|
image_shape=self.image_shape,
|
||||||
|
label=label,
|
||||||
|
char_ops=self.char_ops,
|
||||||
|
loss_type=self.loss_type,
|
||||||
|
max_text_length=self.max_text_length,
|
||||||
|
distort=self.use_distort)
|
||||||
if outs is None:
|
if outs is None:
|
||||||
continue
|
continue
|
||||||
yield outs
|
yield outs
|
||||||
|
|
|
@ -410,7 +410,8 @@ def resize_norm_img_srn(img, image_shape):
|
||||||
|
|
||||||
def srn_other_inputs(image_shape,
|
def srn_other_inputs(image_shape,
|
||||||
num_heads,
|
num_heads,
|
||||||
max_text_length):
|
max_text_length,
|
||||||
|
char_num):
|
||||||
|
|
||||||
imgC, imgH, imgW = image_shape
|
imgC, imgH, imgW = image_shape
|
||||||
feature_dim = int((imgH / 8) * (imgW / 8))
|
feature_dim = int((imgH / 8) * (imgW / 8))
|
||||||
|
@ -418,7 +419,7 @@ def srn_other_inputs(image_shape,
|
||||||
encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype('int64')
|
encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype('int64')
|
||||||
gsrm_word_pos = np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype('int64')
|
gsrm_word_pos = np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype('int64')
|
||||||
|
|
||||||
lbl_weight = np.array([37] * max_text_length).reshape((-1,1)).astype('int64')
|
lbl_weight = np.array([int(char_num-1)] * max_text_length).reshape((-1,1)).astype('int64')
|
||||||
|
|
||||||
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||||
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape([-1, 1, max_text_length, max_text_length])
|
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape([-1, 1, max_text_length, max_text_length])
|
||||||
|
@ -441,17 +442,18 @@ def process_image_srn(img,
|
||||||
loss_type=None):
|
loss_type=None):
|
||||||
norm_img = resize_norm_img_srn(img, image_shape)
|
norm_img = resize_norm_img_srn(img, image_shape)
|
||||||
norm_img = norm_img[np.newaxis, :]
|
norm_img = norm_img[np.newaxis, :]
|
||||||
|
char_num = char_ops.get_char_num()
|
||||||
|
|
||||||
[lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
[lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
||||||
srn_other_inputs(image_shape, num_heads, max_text_length)
|
srn_other_inputs(image_shape, num_heads, max_text_length,char_num)
|
||||||
|
|
||||||
if label is not None:
|
if label is not None:
|
||||||
char_num = char_ops.get_char_num()
|
|
||||||
text = char_ops.encode(label)
|
text = char_ops.encode(label)
|
||||||
if len(text) == 0 or len(text) > max_text_length:
|
if len(text) == 0 or len(text) > max_text_length:
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
if loss_type == "srn":
|
if loss_type == "srn":
|
||||||
text_padded = [37] * max_text_length
|
text_padded = [int(char_num-1)] * max_text_length
|
||||||
for i in range(len(text)):
|
for i in range(len(text)):
|
||||||
text_padded[i] = text[i]
|
text_padded[i] = text[i]
|
||||||
lbl_weight[i] = [1.0]
|
lbl_weight[i] = [1.0]
|
||||||
|
|
|
@ -81,6 +81,23 @@ class ResNet():
|
||||||
num_filters=num_filters[block],
|
num_filters=num_filters[block],
|
||||||
stride=stride_list[block] if i == 0 else 1, name=conv_name)
|
stride=stride_list[block] if i == 0 else 1, name=conv_name)
|
||||||
F.append(conv)
|
F.append(conv)
|
||||||
|
else:
|
||||||
|
for block in range(len(depth)):
|
||||||
|
for i in range(depth[block]):
|
||||||
|
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||||
|
|
||||||
|
if i == 0 and block != 0:
|
||||||
|
stride = (2, 1)
|
||||||
|
else:
|
||||||
|
stride = (1, 1)
|
||||||
|
|
||||||
|
conv = self.basic_block(
|
||||||
|
input=conv,
|
||||||
|
num_filters=num_filters[block],
|
||||||
|
stride=stride,
|
||||||
|
if_first=block == i == 0,
|
||||||
|
name=conv_name)
|
||||||
|
F.append(conv)
|
||||||
|
|
||||||
base = F[-1]
|
base = F[-1]
|
||||||
for i in [-2, -3]:
|
for i in [-2, -3]:
|
||||||
|
|
|
@ -26,8 +26,6 @@ class CharacterOps(object):
|
||||||
self.character_type = config['character_type']
|
self.character_type = config['character_type']
|
||||||
self.loss_type = config['loss_type']
|
self.loss_type = config['loss_type']
|
||||||
self.max_text_len = config['max_text_length']
|
self.max_text_len = config['max_text_length']
|
||||||
if self.loss_type == "srn" and self.character_type != "en":
|
|
||||||
raise Exception("SRN can only support in character_type == en")
|
|
||||||
if self.character_type == "en":
|
if self.character_type == "en":
|
||||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||||
dict_character = list(self.character_str)
|
dict_character = list(self.character_str)
|
||||||
|
@ -160,13 +158,15 @@ def cal_predicts_accuracy_srn(char_ops,
|
||||||
acc_num = 0
|
acc_num = 0
|
||||||
img_num = 0
|
img_num = 0
|
||||||
|
|
||||||
|
char_num = char_ops.get_char_num()
|
||||||
|
|
||||||
total_len = preds.shape[0]
|
total_len = preds.shape[0]
|
||||||
img_num = int(total_len / max_text_len)
|
img_num = int(total_len / max_text_len)
|
||||||
for i in range(img_num):
|
for i in range(img_num):
|
||||||
cur_label = []
|
cur_label = []
|
||||||
cur_pred = []
|
cur_pred = []
|
||||||
for j in range(max_text_len):
|
for j in range(max_text_len):
|
||||||
if labels[j + i * max_text_len] != 37: #0
|
if labels[j + i * max_text_len] != int(char_num-1): #0
|
||||||
cur_label.append(labels[j + i * max_text_len][0])
|
cur_label.append(labels[j + i * max_text_len][0])
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
@ -178,7 +178,7 @@ def cal_predicts_accuracy_srn(char_ops,
|
||||||
elif j == len(cur_label) and j == max_text_len:
|
elif j == len(cur_label) and j == max_text_len:
|
||||||
acc_num += 1
|
acc_num += 1
|
||||||
break
|
break
|
||||||
elif j == len(cur_label) and preds[j + i * max_text_len][0] == 37:
|
elif j == len(cur_label) and preds[j + i * max_text_len][0] == int(char_num-1):
|
||||||
acc_num += 1
|
acc_num += 1
|
||||||
break
|
break
|
||||||
acc = acc_num * 1.0 / img_num
|
acc = acc_num * 1.0 / img_num
|
||||||
|
|
|
@ -140,12 +140,12 @@ def main():
|
||||||
preds = preds.reshape(-1)
|
preds = preds.reshape(-1)
|
||||||
preds_text = char_ops.decode(preds)
|
preds_text = char_ops.decode(preds)
|
||||||
elif loss_type == "srn":
|
elif loss_type == "srn":
|
||||||
cur_pred = []
|
char_num = char_ops.get_char_num()
|
||||||
preds = np.array(predict[0])
|
preds = np.array(predict[0])
|
||||||
preds = preds.reshape(-1)
|
preds = preds.reshape(-1)
|
||||||
probs = np.array(predict[1])
|
probs = np.array(predict[1])
|
||||||
ind = np.argmax(probs, axis=1)
|
ind = np.argmax(probs, axis=1)
|
||||||
valid_ind = np.where(preds != 37)[0]
|
valid_ind = np.where(preds != int(char_num-1))[0]
|
||||||
if len(valid_ind) == 0:
|
if len(valid_ind) == 0:
|
||||||
continue
|
continue
|
||||||
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
||||||
|
|
Loading…
Reference in New Issue