Merge pull request #609 from tink2123/adaptation_ch
Adaptation chinese for SRN
This commit is contained in:
commit
2bdaea5656
|
@ -27,7 +27,7 @@ Architecture:
|
|||
function: ppocr.modeling.architectures.rec_model,RecModel
|
||||
|
||||
Backbone:
|
||||
function: ppocr.modeling.backbones.rec_resnet50_fpn,ResNet
|
||||
function: ppocr.modeling.backbones.rec_resnet_fpn,ResNet
|
||||
layers: 50
|
||||
|
||||
Head:
|
||||
|
|
|
@ -45,7 +45,7 @@ At present, the open source model, dataset and magnitude are as follows:
|
|||
Among them, the public datasets are opensourced, users can search and download by themselves, or refer to [Chinese data set](./datasets_en.md), synthetic data is not opensourced, users can use open-source synthesis tools to synthesize data themselves. Current available synthesis tools include [text_renderer](https://github.com/Sanster/text_renderer), [SynthText](https://github.com/ankush-me/SynthText), [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator), etc.
|
||||
|
||||
10. **Error in using the model with TPS module for prediction**
|
||||
Error message: Input(X) dims[3] and Input(Grid) dims[2] should be equal, but received X dimension[3](108) != Grid dimension[2](100)
|
||||
Error message: Input(X) dims[3] and Input(Grid) dims[2] should be equal, but received X dimension[3]\(108) != Grid dimension[2]\(100)
|
||||
Solution:TPS does not support variable shape. Please set --rec_image_shape='3,32,100' and --rec_char_type='en'
|
||||
|
||||
11. **Custom dictionary used during training, the recognition results show that words do not appear in the dictionary**
|
||||
|
|
|
@ -214,6 +214,8 @@ class SimpleReader(object):
|
|||
self.mode = params['mode']
|
||||
self.infer_img = params['infer_img']
|
||||
self.use_tps = False
|
||||
if "num_heads" in params:
|
||||
self.num_heads = params['num_heads']
|
||||
if "tps" in params:
|
||||
self.use_tps = True
|
||||
self.use_distort = False
|
||||
|
@ -251,12 +253,19 @@ class SimpleReader(object):
|
|||
img = cv2.imread(single_img)
|
||||
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
norm_img = process_image(
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
char_ops=self.char_ops,
|
||||
tps=self.use_tps,
|
||||
infer_mode=True)
|
||||
if self.loss_type == 'srn':
|
||||
norm_img = process_image_srn(
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
num_heads=self.num_heads,
|
||||
max_text_length=self.max_text_length)
|
||||
else:
|
||||
norm_img = process_image(
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
char_ops=self.char_ops,
|
||||
tps=self.use_tps,
|
||||
infer_mode=True)
|
||||
yield norm_img
|
||||
else:
|
||||
with open(self.label_file_path, "rb") as fin:
|
||||
|
@ -286,14 +295,25 @@ class SimpleReader(object):
|
|||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
|
||||
label = substr[1]
|
||||
outs = process_image(
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
label=label,
|
||||
char_ops=self.char_ops,
|
||||
loss_type=self.loss_type,
|
||||
max_text_length=self.max_text_length,
|
||||
distort=self.use_distort)
|
||||
if self.loss_type == "srn":
|
||||
outs = process_image_srn(
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
num_heads=self.num_heads,
|
||||
max_text_length=self.max_text_length,
|
||||
label=label,
|
||||
char_ops=self.char_ops,
|
||||
loss_type=self.loss_type)
|
||||
|
||||
else:
|
||||
outs = process_image(
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
label=label,
|
||||
char_ops=self.char_ops,
|
||||
loss_type=self.loss_type,
|
||||
max_text_length=self.max_text_length,
|
||||
distort=self.use_distort)
|
||||
if outs is None:
|
||||
continue
|
||||
yield outs
|
||||
|
|
|
@ -410,7 +410,8 @@ def resize_norm_img_srn(img, image_shape):
|
|||
|
||||
def srn_other_inputs(image_shape,
|
||||
num_heads,
|
||||
max_text_length):
|
||||
max_text_length,
|
||||
char_num):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
feature_dim = int((imgH / 8) * (imgW / 8))
|
||||
|
@ -418,7 +419,7 @@ def srn_other_inputs(image_shape,
|
|||
encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype('int64')
|
||||
gsrm_word_pos = np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype('int64')
|
||||
|
||||
lbl_weight = np.array([37] * max_text_length).reshape((-1,1)).astype('int64')
|
||||
lbl_weight = np.array([int(char_num-1)] * max_text_length).reshape((-1,1)).astype('int64')
|
||||
|
||||
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape([-1, 1, max_text_length, max_text_length])
|
||||
|
@ -441,17 +442,18 @@ def process_image_srn(img,
|
|||
loss_type=None):
|
||||
norm_img = resize_norm_img_srn(img, image_shape)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
char_num = char_ops.get_char_num()
|
||||
|
||||
[lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
|
||||
srn_other_inputs(image_shape, num_heads, max_text_length)
|
||||
srn_other_inputs(image_shape, num_heads, max_text_length,char_num)
|
||||
|
||||
if label is not None:
|
||||
char_num = char_ops.get_char_num()
|
||||
text = char_ops.encode(label)
|
||||
if len(text) == 0 or len(text) > max_text_length:
|
||||
return None
|
||||
else:
|
||||
if loss_type == "srn":
|
||||
text_padded = [37] * max_text_length
|
||||
text_padded = [int(char_num-1)] * max_text_length
|
||||
for i in range(len(text)):
|
||||
text_padded[i] = text[i]
|
||||
lbl_weight[i] = [1.0]
|
||||
|
|
|
@ -22,12 +22,12 @@ import paddle
|
|||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
|
||||
|
||||
__all__ = ["ResNet", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"]
|
||||
__all__ = [
|
||||
"ResNet", "ResNet18", "ResNet34", "ResNet50", "ResNet101", "ResNet152"
|
||||
]
|
||||
|
||||
Trainable = True
|
||||
w_nolr = fluid.ParamAttr(
|
||||
trainable = Trainable)
|
||||
w_nolr = fluid.ParamAttr(trainable=Trainable)
|
||||
train_parameters = {
|
||||
"input_size": [3, 224, 224],
|
||||
"input_mean": [0.485, 0.456, 0.406],
|
||||
|
@ -40,12 +40,12 @@ train_parameters = {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
class ResNet():
|
||||
def __init__(self, params):
|
||||
self.layers = params['layers']
|
||||
self.params = train_parameters
|
||||
|
||||
|
||||
def __call__(self, input):
|
||||
layers = self.layers
|
||||
supported_layers = [18, 34, 50, 101, 152]
|
||||
|
@ -60,11 +60,16 @@ class ResNet():
|
|||
depth = [3, 4, 23, 3]
|
||||
elif layers == 152:
|
||||
depth = [3, 8, 36, 3]
|
||||
stride_list = [(2,2),(2,2),(1,1),(1,1)]
|
||||
stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)]
|
||||
num_filters = [64, 128, 256, 512]
|
||||
|
||||
conv = self.conv_bn_layer(
|
||||
input=input, num_filters=64, filter_size=7, stride=2, act='relu', name="conv1")
|
||||
input=input,
|
||||
num_filters=64,
|
||||
filter_size=7,
|
||||
stride=2,
|
||||
act='relu',
|
||||
name="conv1")
|
||||
F = []
|
||||
if layers >= 50:
|
||||
for block in range(len(depth)):
|
||||
|
@ -79,26 +84,67 @@ class ResNet():
|
|||
conv = self.bottleneck_block(
|
||||
input=conv,
|
||||
num_filters=num_filters[block],
|
||||
stride=stride_list[block] if i == 0 else 1, name=conv_name)
|
||||
stride=stride_list[block] if i == 0 else 1,
|
||||
name=conv_name)
|
||||
F.append(conv)
|
||||
else:
|
||||
for block in range(len(depth)):
|
||||
for i in range(depth[block]):
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
|
||||
if i == 0 and block != 0:
|
||||
stride = (2, 1)
|
||||
else:
|
||||
stride = (1, 1)
|
||||
|
||||
conv = self.basic_block(
|
||||
input=conv,
|
||||
num_filters=num_filters[block],
|
||||
stride=stride,
|
||||
if_first=block == i == 0,
|
||||
name=conv_name)
|
||||
F.append(conv)
|
||||
|
||||
base = F[-1]
|
||||
for i in [-2, -3]:
|
||||
b, c, w, h = F[i].shape
|
||||
if (w,h) == base.shape[2:]:
|
||||
if (w, h) == base.shape[2:]:
|
||||
base = base
|
||||
else:
|
||||
base = fluid.layers.conv2d_transpose( input=base, num_filters=c,filter_size=4, stride=2,
|
||||
padding=1,act=None,
|
||||
base = fluid.layers.conv2d_transpose(
|
||||
input=base,
|
||||
num_filters=c,
|
||||
filter_size=4,
|
||||
stride=2,
|
||||
padding=1,
|
||||
act=None,
|
||||
param_attr=w_nolr,
|
||||
bias_attr=w_nolr)
|
||||
base = fluid.layers.batch_norm(base, act = "relu", param_attr=w_nolr, bias_attr=w_nolr)
|
||||
base = fluid.layers.batch_norm(
|
||||
base, act="relu", param_attr=w_nolr, bias_attr=w_nolr)
|
||||
base = fluid.layers.concat([base, F[i]], axis=1)
|
||||
base = fluid.layers.conv2d(base, num_filters=c, filter_size=1, param_attr=w_nolr, bias_attr=w_nolr)
|
||||
base = fluid.layers.conv2d(base, num_filters=c, filter_size=3,padding = 1, param_attr=w_nolr, bias_attr=w_nolr)
|
||||
base = fluid.layers.batch_norm(base, act = "relu", param_attr=w_nolr, bias_attr=w_nolr)
|
||||
base = fluid.layers.conv2d(
|
||||
base,
|
||||
num_filters=c,
|
||||
filter_size=1,
|
||||
param_attr=w_nolr,
|
||||
bias_attr=w_nolr)
|
||||
base = fluid.layers.conv2d(
|
||||
base,
|
||||
num_filters=c,
|
||||
filter_size=3,
|
||||
padding=1,
|
||||
param_attr=w_nolr,
|
||||
bias_attr=w_nolr)
|
||||
base = fluid.layers.batch_norm(
|
||||
base, act="relu", param_attr=w_nolr, bias_attr=w_nolr)
|
||||
|
||||
base = fluid.layers.conv2d(base, num_filters=512, filter_size=1,bias_attr=w_nolr,param_attr=w_nolr)
|
||||
base = fluid.layers.conv2d(
|
||||
base,
|
||||
num_filters=512,
|
||||
filter_size=1,
|
||||
bias_attr=w_nolr,
|
||||
param_attr=w_nolr)
|
||||
|
||||
return base
|
||||
|
||||
|
@ -113,13 +159,14 @@ class ResNet():
|
|||
conv = fluid.layers.conv2d(
|
||||
input=input,
|
||||
num_filters=num_filters,
|
||||
filter_size= 2 if stride==(1,1) else filter_size,
|
||||
dilation = 2 if stride==(1,1) else 1,
|
||||
filter_size=2 if stride == (1, 1) else filter_size,
|
||||
dilation=2 if stride == (1, 1) else 1,
|
||||
stride=stride,
|
||||
padding=(filter_size - 1) // 2,
|
||||
groups=groups,
|
||||
act=None,
|
||||
param_attr=ParamAttr(name=name + "_weights",trainable = Trainable),
|
||||
param_attr=ParamAttr(
|
||||
name=name + "_weights", trainable=Trainable),
|
||||
bias_attr=False,
|
||||
name=name + '.conv2d.output.1')
|
||||
|
||||
|
@ -127,20 +174,23 @@ class ResNet():
|
|||
bn_name = "bn_" + name
|
||||
else:
|
||||
bn_name = "bn" + name[3:]
|
||||
return fluid.layers.batch_norm(input=conv,
|
||||
act=act,
|
||||
name=bn_name + '.output.1',
|
||||
param_attr=ParamAttr(name=bn_name + '_scale',trainable = Trainable),
|
||||
bias_attr=ParamAttr(bn_name + '_offset',trainable = Trainable),
|
||||
moving_mean_name=bn_name + '_mean',
|
||||
moving_variance_name=bn_name + '_variance', )
|
||||
return fluid.layers.batch_norm(
|
||||
input=conv,
|
||||
act=act,
|
||||
name=bn_name + '.output.1',
|
||||
param_attr=ParamAttr(
|
||||
name=bn_name + '_scale', trainable=Trainable),
|
||||
bias_attr=ParamAttr(
|
||||
bn_name + '_offset', trainable=Trainable),
|
||||
moving_mean_name=bn_name + '_mean',
|
||||
moving_variance_name=bn_name + '_variance', )
|
||||
|
||||
def shortcut(self, input, ch_out, stride, is_first, name):
|
||||
ch_in = input.shape[1]
|
||||
if ch_in != ch_out or stride != 1 or is_first == True:
|
||||
if stride == (1,1):
|
||||
if stride == (1, 1):
|
||||
return self.conv_bn_layer(input, ch_out, 1, 1, name=name)
|
||||
else: #stride == (2,2)
|
||||
else: #stride == (2,2)
|
||||
return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
|
||||
|
||||
else:
|
||||
|
@ -148,7 +198,11 @@ class ResNet():
|
|||
|
||||
def bottleneck_block(self, input, num_filters, stride, name):
|
||||
conv0 = self.conv_bn_layer(
|
||||
input=input, num_filters=num_filters, filter_size=1, act='relu', name=name + "_branch2a")
|
||||
input=input,
|
||||
num_filters=num_filters,
|
||||
filter_size=1,
|
||||
act='relu',
|
||||
name=name + "_branch2a")
|
||||
conv1 = self.conv_bn_layer(
|
||||
input=conv0,
|
||||
num_filters=num_filters,
|
||||
|
@ -157,16 +211,36 @@ class ResNet():
|
|||
act='relu',
|
||||
name=name + "_branch2b")
|
||||
conv2 = self.conv_bn_layer(
|
||||
input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, name=name + "_branch2c")
|
||||
input=conv1,
|
||||
num_filters=num_filters * 4,
|
||||
filter_size=1,
|
||||
act=None,
|
||||
name=name + "_branch2c")
|
||||
|
||||
short = self.shortcut(input, num_filters * 4, stride, is_first=False, name=name + "_branch1")
|
||||
short = self.shortcut(
|
||||
input,
|
||||
num_filters * 4,
|
||||
stride,
|
||||
is_first=False,
|
||||
name=name + "_branch1")
|
||||
|
||||
return fluid.layers.elementwise_add(x=short, y=conv2, act='relu', name=name + ".add.output.5")
|
||||
return fluid.layers.elementwise_add(
|
||||
x=short, y=conv2, act='relu', name=name + ".add.output.5")
|
||||
|
||||
def basic_block(self, input, num_filters, stride, is_first, name):
|
||||
conv0 = self.conv_bn_layer(input=input, num_filters=num_filters, filter_size=3, act='relu', stride=stride,
|
||||
name=name + "_branch2a")
|
||||
conv1 = self.conv_bn_layer(input=conv0, num_filters=num_filters, filter_size=3, act=None,
|
||||
name=name + "_branch2b")
|
||||
short = self.shortcut(input, num_filters, stride, is_first, name=name + "_branch1")
|
||||
conv0 = self.conv_bn_layer(
|
||||
input=input,
|
||||
num_filters=num_filters,
|
||||
filter_size=3,
|
||||
act='relu',
|
||||
stride=stride,
|
||||
name=name + "_branch2a")
|
||||
conv1 = self.conv_bn_layer(
|
||||
input=conv0,
|
||||
num_filters=num_filters,
|
||||
filter_size=3,
|
||||
act=None,
|
||||
name=name + "_branch2b")
|
||||
short = self.shortcut(
|
||||
input, num_filters, stride, is_first, name=name + "_branch1")
|
||||
return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
|
|
@ -26,8 +26,6 @@ class CharacterOps(object):
|
|||
self.character_type = config['character_type']
|
||||
self.loss_type = config['loss_type']
|
||||
self.max_text_len = config['max_text_length']
|
||||
if self.loss_type == "srn" and self.character_type != "en":
|
||||
raise Exception("SRN can only support in character_type == en")
|
||||
if self.character_type == "en":
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
|
@ -160,13 +158,15 @@ def cal_predicts_accuracy_srn(char_ops,
|
|||
acc_num = 0
|
||||
img_num = 0
|
||||
|
||||
char_num = char_ops.get_char_num()
|
||||
|
||||
total_len = preds.shape[0]
|
||||
img_num = int(total_len / max_text_len)
|
||||
for i in range(img_num):
|
||||
cur_label = []
|
||||
cur_pred = []
|
||||
for j in range(max_text_len):
|
||||
if labels[j + i * max_text_len] != 37: #0
|
||||
if labels[j + i * max_text_len] != int(char_num-1): #0
|
||||
cur_label.append(labels[j + i * max_text_len][0])
|
||||
else:
|
||||
break
|
||||
|
@ -178,7 +178,7 @@ def cal_predicts_accuracy_srn(char_ops,
|
|||
elif j == len(cur_label) and j == max_text_len:
|
||||
acc_num += 1
|
||||
break
|
||||
elif j == len(cur_label) and preds[j + i * max_text_len][0] == 37:
|
||||
elif j == len(cur_label) and preds[j + i * max_text_len][0] == int(char_num-1):
|
||||
acc_num += 1
|
||||
break
|
||||
acc = acc_num * 1.0 / img_num
|
||||
|
|
|
@ -140,12 +140,12 @@ def main():
|
|||
preds = preds.reshape(-1)
|
||||
preds_text = char_ops.decode(preds)
|
||||
elif loss_type == "srn":
|
||||
cur_pred = []
|
||||
char_num = char_ops.get_char_num()
|
||||
preds = np.array(predict[0])
|
||||
preds = preds.reshape(-1)
|
||||
probs = np.array(predict[1])
|
||||
ind = np.argmax(probs, axis=1)
|
||||
valid_ind = np.where(preds != 37)[0]
|
||||
valid_ind = np.where(preds != int(char_num-1))[0]
|
||||
if len(valid_ind) == 0:
|
||||
continue
|
||||
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
||||
|
|
Loading…
Reference in New Issue