update for srn
This commit is contained in:
parent
1e8f414662
commit
09d8cb6d98
|
@ -0,0 +1,48 @@
|
|||
Global:
|
||||
algorithm: SRN
|
||||
use_gpu: true
|
||||
epoch_num: 72
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: output/rec_pvam_withrotate
|
||||
save_epoch_step: 1
|
||||
eval_batch_step: 8000
|
||||
train_batch_size_per_card: 64
|
||||
test_batch_size_per_card: 1
|
||||
image_shape: [1, 64, 256]
|
||||
max_text_length: 25
|
||||
character_type: en
|
||||
loss_type: srn
|
||||
num_heads: 8
|
||||
average_window: 0.15
|
||||
max_average_window: 15625
|
||||
min_average_window: 10000
|
||||
reader_yml: ./configs/rec/rec_srn_reader.yml
|
||||
pretrain_weights:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
|
||||
Architecture:
|
||||
function: ppocr.modeling.architectures.rec_model,RecModel
|
||||
|
||||
Backbone:
|
||||
function: ppocr.modeling.backbones.rec_resnet50_fpn,ResNet
|
||||
layers: 50
|
||||
|
||||
Head:
|
||||
function: ppocr.modeling.heads.rec_srn_all_head,SRNPredict
|
||||
encoder_type: rnn
|
||||
num_encoder_TUs: 2
|
||||
num_decoder_TUs: 4
|
||||
hidden_dims: 512
|
||||
SeqRNN:
|
||||
hidden_size: 256
|
||||
|
||||
Loss:
|
||||
function: ppocr.modeling.losses.rec_srn_loss,SRNLoss
|
||||
|
||||
Optimizer:
|
||||
function: ppocr.optimizer,AdamDecay
|
||||
base_lr: 0.0001
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
|
@ -26,7 +26,7 @@ from ppocr.utils.utility import initial_logger
|
|||
from ppocr.utils.utility import get_image_file_list
|
||||
logger = initial_logger()
|
||||
|
||||
from .img_tools import process_image, get_img_data
|
||||
from .img_tools import process_image, process_image_srn, get_img_data
|
||||
|
||||
|
||||
class LMDBReader(object):
|
||||
|
@ -40,6 +40,7 @@ class LMDBReader(object):
|
|||
self.image_shape = params['image_shape']
|
||||
self.loss_type = params['loss_type']
|
||||
self.max_text_length = params['max_text_length']
|
||||
self.num_heads = params['num_heads']
|
||||
self.mode = params['mode']
|
||||
self.drop_last = False
|
||||
self.use_tps = False
|
||||
|
@ -117,14 +118,36 @@ class LMDBReader(object):
|
|||
image_file_list = get_image_file_list(self.infer_img)
|
||||
for single_img in image_file_list:
|
||||
img = cv2.imread(single_img)
|
||||
if img.shape[-1] == 1 or len(list(img.shape)) == 2:
|
||||
if img.shape[-1]==1 or len(list(img.shape))==2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
if self.loss_type == 'srn':
|
||||
norm_img = process_image_srn(
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
num_heads=self.num_heads,
|
||||
max_text_length=self.max_text_length
|
||||
)
|
||||
else:
|
||||
norm_img = process_image(
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
char_ops=self.char_ops,
|
||||
tps=self.use_tps,
|
||||
infer_mode=True)
|
||||
yield norm_img
|
||||
elif self.mode == 'test':
|
||||
image_file_list = get_image_file_list(self.infer_img)
|
||||
for single_img in image_file_list:
|
||||
img = cv2.imread(single_img)
|
||||
if img.shape[-1]==1 or len(list(img.shape))==2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
norm_img = process_image(
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
char_ops=self.char_ops,
|
||||
tps=self.use_tps,
|
||||
infer_mode=True)
|
||||
infer_mode=True
|
||||
)
|
||||
yield norm_img
|
||||
else:
|
||||
lmdb_sets = self.load_hierarchical_lmdb_dataset()
|
||||
|
@ -144,14 +167,16 @@ class LMDBReader(object):
|
|||
if sample_info is None:
|
||||
continue
|
||||
img, label = sample_info
|
||||
outs = process_image(
|
||||
img=img,
|
||||
image_shape=self.image_shape,
|
||||
label=label,
|
||||
char_ops=self.char_ops,
|
||||
loss_type=self.loss_type,
|
||||
max_text_length=self.max_text_length,
|
||||
distort=self.use_distort)
|
||||
outs = []
|
||||
if self.loss_type == "srn":
|
||||
outs = process_image_srn(img, self.image_shape, self.num_heads,
|
||||
self.max_text_length, label,
|
||||
self.char_ops, self.loss_type)
|
||||
|
||||
else:
|
||||
outs = process_image(img, self.image_shape, label,
|
||||
self.char_ops, self.loss_type,
|
||||
self.max_text_length)
|
||||
if outs is None:
|
||||
continue
|
||||
yield outs
|
||||
|
@ -159,7 +184,6 @@ class LMDBReader(object):
|
|||
if finish_read_num == len(lmdb_sets):
|
||||
break
|
||||
self.close_lmdb_dataset(lmdb_sets)
|
||||
|
||||
def batch_iter_reader():
|
||||
batch_outs = []
|
||||
for outs in sample_iter_reader():
|
||||
|
@ -167,9 +191,8 @@ class LMDBReader(object):
|
|||
if len(batch_outs) == self.batch_size:
|
||||
yield batch_outs
|
||||
batch_outs = []
|
||||
if not self.drop_last:
|
||||
if len(batch_outs) != 0:
|
||||
yield batch_outs
|
||||
if len(batch_outs) != 0:
|
||||
yield batch_outs
|
||||
|
||||
if self.infer_img is None:
|
||||
return batch_iter_reader
|
||||
|
@ -288,4 +311,4 @@ class SimpleReader(object):
|
|||
|
||||
if self.infer_img is None:
|
||||
return batch_iter_reader
|
||||
return sample_iter_reader
|
||||
return sample_iter_reader
|
|
@ -381,3 +381,84 @@ def process_image(img,
|
|||
assert False, "Unsupport loss_type %s in process_image"\
|
||||
% loss_type
|
||||
return (norm_img)
|
||||
|
||||
def resize_norm_img_srn(img, image_shape):
    """Resize a BGR image and left-align it on a zero-filled grayscale canvas.

    The image is resized to height ``imgH``. Its width is snapped to the
    first bucket ``imgH * 1``, ``imgH * 2`` or ``imgH * 3`` whose aspect
    ratio covers the input; wider inputs are resized to the full ``imgW``.
    The result is converted with ``cv2.COLOR_BGR2GRAY`` and copied into the
    left part of an ``(imgH, imgW)`` canvas of zeros.

    Args:
        img: image array; ``img.shape[0]`` / ``img.shape[1]`` are read as
            height / width, and it must be accepted by ``COLOR_BGR2GRAY``.
        image_shape: sequence ``(imgC, imgH, imgW)``; ``imgC`` is not used.

    Returns:
        ``float32`` array of shape ``(1, imgH, imgW)``.
    """
    _, imgH, imgW = image_shape
    im_hei, im_wid = img.shape[0], img.shape[1]

    # Pick the narrowest width bucket that fits the aspect ratio; anything
    # wider than 3:1 uses the whole canvas width.
    target_w = imgW
    for ratio in (1, 2, 3):
        if im_wid <= im_hei * ratio:
            target_w = imgH * ratio
            break

    # A bucket wider than the canvas (imgW < imgH * ratio) used to make the
    # paste below fail with a broadcast error; clamp it to the canvas width.
    target_w = min(target_w, imgW)

    img_new = cv2.resize(img, (target_w, imgH))
    img_gray = cv2.cvtColor(np.asarray(img_new), cv2.COLOR_BGR2GRAY)

    canvas = np.zeros((imgH, imgW))
    canvas[:, 0:img_gray.shape[1]] = img_gray

    # Add the leading channel axis: (imgH, imgW) -> (1, imgH, imgW).
    return canvas[np.newaxis, :, :].astype(np.float32)
|
||||
|
||||
def srn_other_inputs(image_shape, num_heads, max_text_length):
    """Build the constant auxiliary arrays that accompany an SRN image.

    Args:
        image_shape: sequence ``(imgC, imgH, imgW)``; only height and width
            are used, to derive ``feature_dim = int((imgH / 8) * (imgW / 8))``.
        num_heads: number of copies of each bias along axis 1.
        max_text_length: length of the text axis.

    Returns:
        A list ``[lbl_weight, encoder_word_pos, gsrm_word_pos,
        gsrm_slf_attn_bias1, gsrm_slf_attn_bias2]`` where

        * ``lbl_weight`` is int64, shape ``(max_text_length, 1)``, filled
          with 37;
        * ``encoder_word_pos`` is int64, shape ``(1, feature_dim, 1)``,
          holding ``0 .. feature_dim - 1``;
        * ``gsrm_word_pos`` is int64, shape ``(1, max_text_length, 1)``,
          holding ``0 .. max_text_length - 1``;
        * the two biases are float64, shape
          ``(1, num_heads, max_text_length, max_text_length)``, with ``-1e9``
          strictly above (bias1) / strictly below (bias2) the diagonal.
    """
    _, height, width = image_shape
    feature_dim = int((height / 8) * (width / 8))
    square = [-1, 1, max_text_length, max_text_length]
    repeat_heads = [1, num_heads, 1, 1]

    def _position_ids(count):
        # Column of 0..count-1 with a leading batch axis of size 1.
        column = np.arange(count).reshape((count, 1)).astype('int64')
        return column[np.newaxis, :]

    def _masked_bias(triangle):
        # Tile the 0/1 triangle across heads and scale the ones to -1e9.
        return np.tile(triangle.reshape(square), repeat_heads) * [-1e9]

    lbl_weight = np.full((max_text_length, 1), 37, dtype='int64')

    all_ones = np.ones((1, max_text_length, max_text_length))
    upper_bias = _masked_bias(np.triu(all_ones, 1))
    lower_bias = _masked_bias(np.tril(all_ones, -1))

    return [
        lbl_weight,
        _position_ids(feature_dim),
        _position_ids(max_text_length),
        upper_bias,
        lower_bias,
    ]
|
||||
|
||||
def process_image_srn(img,
                      image_shape,
                      num_heads,
                      max_text_length,
                      label=None,
                      char_ops=None,
                      loss_type=None):
    """Prepare one sample (image plus auxiliary arrays) for the SRN model.

    Args:
        img: image passed to ``resize_norm_img_srn``.
        image_shape: ``(imgC, imgH, imgW)`` target shape.
        num_heads: forwarded to ``srn_other_inputs``.
        max_text_length: maximum number of encoded characters per label.
        label: raw label; ``None`` selects the inference output.
        char_ops: object whose ``encode(label)`` returns the encoded text;
            only used when ``label`` is given.
        loss_type: must be ``"srn"`` when ``label`` is given.

    Returns:
        * without a label: ``(norm_img, encoder_word_pos, gsrm_word_pos,
          gsrm_slf_attn_bias1, gsrm_slf_attn_bias2)``;
        * with a label: ``(norm_img, text, encoder_word_pos, gsrm_word_pos,
          gsrm_slf_attn_bias1, gsrm_slf_attn_bias2, lbl_weight)`` where
          ``text`` is the encoded label padded with 37 to
          ``(max_text_length, 1)`` and ``lbl_weight`` is 1 on the label
          positions and 37 elsewhere;
        * ``None`` when the encoded label is empty or longer than
          ``max_text_length``.

    Raises:
        AssertionError: a label was given with a ``loss_type`` other than
            ``"srn"``.
    """
    pad_idx = 37  # padding id used for both the label and lbl_weight

    norm_img = resize_norm_img_srn(img, image_shape)
    norm_img = norm_img[np.newaxis, :]
    (lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
     gsrm_slf_attn_bias2) = srn_other_inputs(image_shape, num_heads,
                                             max_text_length)

    if label is None:
        return (norm_img, encoder_word_pos, gsrm_word_pos,
                gsrm_slf_attn_bias1, gsrm_slf_attn_bias2)

    text = char_ops.encode(label)
    text_len = len(text)
    if text_len == 0 or text_len > max_text_length:
        return None

    if loss_type != "srn":
        # Raised explicitly (instead of ``assert False``) so the guard is not
        # stripped under ``python -O``; the exception type is unchanged.
        raise AssertionError("Unsupport loss_type %s in process_image" %
                             loss_type)

    text_padded = list(text) + [pad_idx] * (max_text_length - text_len)
    lbl_weight[:text_len] = 1
    text = np.array(text_padded).reshape(-1, 1)
    return (norm_img, text, encoder_word_pos, gsrm_word_pos,
            gsrm_slf_attn_bias1, gsrm_slf_attn_bias2, lbl_weight)
|
||||
|
|
|
@ -58,6 +58,7 @@ class RecModel(object):
|
|||
self.loss_type = global_params['loss_type']
|
||||
self.image_shape = global_params['image_shape']
|
||||
self.max_text_length = global_params['max_text_length']
|
||||
self.num_heads = global_params["num_heads"]
|
||||
|
||||
def create_feed(self, mode):
|
||||
image_shape = deepcopy(self.image_shape)
|
||||
|
@ -77,6 +78,18 @@ class RecModel(object):
|
|||
lod_level=1)
|
||||
feed_list = [image, label_in, label_out]
|
||||
labels = {'label_in': label_in, 'label_out': label_out}
|
||||
elif self.loss_type == "srn":
|
||||
encoder_word_pos = fluid.data(name="encoder_word_pos", shape=[-1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)), 1], dtype="int64")
|
||||
gsrm_word_pos = fluid.data(name="gsrm_word_pos", shape=[-1, self.max_text_length, 1], dtype="int64")
|
||||
gsrm_slf_attn_bias1 = fluid.data(name="gsrm_slf_attn_bias1", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length])
|
||||
gsrm_slf_attn_bias2 = fluid.data(name="gsrm_slf_attn_bias2", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length])
|
||||
lbl_weight = fluid.layers.data(name="lbl_weight", shape=[-1, 1], dtype='int64')
|
||||
label = fluid.data(
|
||||
name='label', shape=[-1, 1], dtype='int32', lod_level=1)
|
||||
feed_list = [image, label, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2, lbl_weight]
|
||||
labels = {'label': label, 'encoder_word_pos': encoder_word_pos,
|
||||
'gsrm_word_pos': gsrm_word_pos, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
|
||||
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2,'lbl_weight':lbl_weight}
|
||||
else:
|
||||
label = fluid.data(
|
||||
name='label', shape=[None, 1], dtype='int32', lod_level=1)
|
||||
|
@ -88,6 +101,8 @@ class RecModel(object):
|
|||
use_double_buffer=True,
|
||||
iterable=False)
|
||||
else:
|
||||
labels = None
|
||||
loader = None
|
||||
if self.char_type == "ch" and self.infer_img:
|
||||
image_shape[-1] = -1
|
||||
if self.tps != None:
|
||||
|
@ -97,9 +112,15 @@ class RecModel(object):
|
|||
"We set img_shape to be the same , it may affect the inference effect"
|
||||
)
|
||||
image_shape = deepcopy(self.image_shape)
|
||||
image = fluid.data(name='image', shape=image_shape, dtype='float32')
|
||||
labels = None
|
||||
loader = None
|
||||
image = fluid.data(name='image', shape=image_shape, dtype='float32')
|
||||
if self.loss_type == "srn":
|
||||
encoder_word_pos = fluid.data(name="encoder_word_pos", shape=[-1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)), 1], dtype="int64")
|
||||
gsrm_word_pos = fluid.data(name="gsrm_word_pos", shape=[-1, self.max_text_length, 1], dtype="int64")
|
||||
gsrm_slf_attn_bias1 = fluid.data(name="gsrm_slf_attn_bias1", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length])
|
||||
gsrm_slf_attn_bias2 = fluid.data(name="gsrm_slf_attn_bias2", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length])
|
||||
feed_list = [image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2]
|
||||
labels = {'encoder_word_pos': encoder_word_pos, 'gsrm_word_pos': gsrm_word_pos,
|
||||
'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1, 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2}
|
||||
return image, labels, loader
|
||||
|
||||
def __call__(self, mode):
|
||||
|
@ -117,9 +138,15 @@ class RecModel(object):
|
|||
label = labels['label_out']
|
||||
else:
|
||||
label = labels['label']
|
||||
outputs = {'total_loss':loss, 'decoded_out':\
|
||||
decoded_out, 'label':label}
|
||||
if self.loss_type == 'srn':
|
||||
total_loss, img_loss, word_loss = self.loss(predicts, labels)
|
||||
outputs = {'total_loss':total_loss, 'img_loss':img_loss, 'word_loss':word_loss,
|
||||
'decoded_out':decoded_out, 'label':label}
|
||||
else:
|
||||
outputs = {'total_loss':loss, 'decoded_out':\
|
||||
decoded_out, 'label':label}
|
||||
return loader, outputs
|
||||
|
||||
elif mode == "export":
|
||||
predict = predicts['predict']
|
||||
if self.loss_type == "ctc":
|
||||
|
@ -129,4 +156,4 @@ class RecModel(object):
|
|||
predict = predicts['predict']
|
||||
if self.loss_type == "ctc":
|
||||
predict = fluid.layers.softmax(predict)
|
||||
return loader, {'decoded_out': decoded_out, 'predicts': predict}
|
||||
return loader, {'decoded_out': decoded_out, 'predicts': predict}
|
|
@ -0,0 +1,172 @@
|
|||
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
|
||||
|
||||
# Only ``ResNet`` is defined in this module. The former entries ResNet18 ..
# ResNet152 do not exist here, so ``from <module> import *`` raised
# AttributeError.
__all__ = ["ResNet"]

# Module-wide switch applied to every ParamAttr created in this file.
Trainable = True

# Shared attribute for the layers that do not get an explicit parameter name.
w_nolr = fluid.ParamAttr(trainable=Trainable)

# Stored on ``ResNet.params``; not read anywhere else in this file.
train_parameters = {
    "input_size": [3, 224, 224],
    "input_mean": [0.485, 0.456, 0.406],
    "input_std": [0.229, 0.224, 0.225],
    "learning_strategy": {
        "name": "piecewise_decay",
        "batch_size": 256,
        "epochs": [30, 60, 90],
        "steps": [0.1, 0.01, 0.001, 0.0001]
    }
}
|
||||
|
||||
class ResNet():
    """ResNet bottleneck backbone followed by a top-down feature fusion path.

    The outputs of the four residual stages are collected; the last one is
    then merged with the two preceding ones (upsampling with a transposed
    convolution when the spatial sizes differ) and projected to 512 channels.
    """

    def __init__(self, params):
        """Read the network depth from ``params['layers']``."""
        self.layers = params['layers']
        self.params = train_parameters

    def __call__(self, input):
        """Build the backbone graph on ``input`` and return the fused feature.

        Raises:
            AssertionError: ``layers`` is not one of 18/34/50/101/152.
            NotImplementedError: ``layers`` is 18 or 34. Only the bottleneck
                path (>= 50) is built here; previously these depths produced
                no stages and failed later with a bare IndexError.
        """
        layers = self.layers
        supported_layers = [18, 34, 50, 101, 152]
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(supported_layers, layers)

        if layers < 50:
            raise NotImplementedError(
                "ResNet-{} is not implemented for this backbone; "
                "use 50, 101 or 152 layers".format(layers))

        if layers == 50:
            depth = [3, 4, 6, 3]
        elif layers == 101:
            depth = [3, 4, 23, 3]
        else:  # layers == 152
            depth = [3, 8, 36, 3]
        stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)]
        num_filters = [64, 128, 256, 512]

        conv = self.conv_bn_layer(
            input=input, num_filters=64, filter_size=7, stride=2, act='relu', name="conv1")

        # F collects the output of each residual stage.
        F = []
        for block in range(len(depth)):
            for i in range(depth[block]):
                if layers in [101, 152] and block == 2:
                    if i == 0:
                        conv_name = "res" + str(block + 2) + "a"
                    else:
                        conv_name = "res" + str(block + 2) + "b" + str(i)
                else:
                    conv_name = "res" + str(block + 2) + chr(97 + i)
                conv = self.bottleneck_block(
                    input=conv,
                    num_filters=num_filters[block],
                    stride=stride_list[block] if i == 0 else 1,
                    name=conv_name)
            F.append(conv)

        # Top-down fusion: merge the last stage with the two before it.
        base = F[-1]
        for i in [-2, -3]:
            # Variable shapes are (batch, channels, dim2, dim3).
            _, c, dim2, dim3 = F[i].shape
            if (dim2, dim3) != base.shape[2:]:
                # Spatial sizes differ: upsample by 2 before concatenating.
                base = fluid.layers.conv2d_transpose(
                    input=base,
                    num_filters=c,
                    filter_size=4,
                    stride=2,
                    padding=1,
                    act=None,
                    param_attr=w_nolr,
                    bias_attr=w_nolr)
                base = fluid.layers.batch_norm(
                    base, act="relu", param_attr=w_nolr, bias_attr=w_nolr)
            base = fluid.layers.concat([base, F[i]], axis=1)
            base = fluid.layers.conv2d(
                base, num_filters=c, filter_size=1, param_attr=w_nolr, bias_attr=w_nolr)
            base = fluid.layers.conv2d(
                base, num_filters=c, filter_size=3, padding=1, param_attr=w_nolr, bias_attr=w_nolr)
            base = fluid.layers.batch_norm(
                base, act="relu", param_attr=w_nolr, bias_attr=w_nolr)

        base = fluid.layers.conv2d(
            base, num_filters=512, filter_size=1, bias_attr=w_nolr, param_attr=w_nolr)

        return base

    def conv_bn_layer(self,
                      input,
                      num_filters,
                      filter_size,
                      stride=1,
                      groups=1,
                      act=None,
                      name=None):
        """Bias-free conv2d followed by batch_norm carrying activation ``act``.

        When ``stride`` is exactly the tuple ``(1, 1)`` the convolution uses a
        2x2 kernel with dilation 2; the padding is still computed from the
        requested ``filter_size``. Parameter names are derived from ``name``.
        """
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=2 if stride == (1, 1) else filter_size,
            dilation=2 if stride == (1, 1) else 1,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=name + "_weights", trainable=Trainable),
            bias_attr=False,
            name=name + '.conv2d.output.1')

        # "conv1" -> "bn_conv1"; "res2a_branch2a" -> "bn2a_branch2a".
        if name == "conv1":
            bn_name = "bn_" + name
        else:
            bn_name = "bn" + name[3:]
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            name=bn_name + '.output.1',
            param_attr=ParamAttr(name=bn_name + '_scale', trainable=Trainable),
            bias_attr=ParamAttr(bn_name + '_offset', trainable=Trainable),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance', )

    def shortcut(self, input, ch_out, stride, is_first, name):
        """Return the residual branch: ``input`` itself or a 1x1 projection.

        A projection is used when the channel count changes, when ``stride``
        is not the integer 1, or when ``is_first`` is set.
        """
        ch_in = input.shape[1]
        if ch_in != ch_out or stride != 1 or is_first:
            if stride == (1, 1):
                # Integer stride 1 keeps conv_bn_layer on its plain path
                # (the tuple (1, 1) would select the dilated kernel).
                return self.conv_bn_layer(input, ch_out, 1, 1, name=name)
            # e.g. stride == (2, 2)
            return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
        return input

    def bottleneck_block(self, input, num_filters, stride, name):
        """1x1 -> 3x3 (strided) -> 1x1 (x4 channels) block with a shortcut."""
        conv0 = self.conv_bn_layer(
            input=input, num_filters=num_filters, filter_size=1, act='relu', name=name + "_branch2a")
        conv1 = self.conv_bn_layer(
            input=conv0,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act='relu',
            name=name + "_branch2b")
        conv2 = self.conv_bn_layer(
            input=conv1, num_filters=num_filters * 4, filter_size=1, act=None, name=name + "_branch2c")

        short = self.shortcut(input, num_filters * 4, stride, is_first=False, name=name + "_branch1")

        return fluid.layers.elementwise_add(x=short, y=conv2, act='relu', name=name + ".add.output.5")

    def basic_block(self, input, num_filters, stride, is_first, name):
        """Two 3x3 conv_bn layers with a shortcut (not called by __call__)."""
        conv0 = self.conv_bn_layer(
            input=input, num_filters=num_filters, filter_size=3, act='relu', stride=stride,
            name=name + "_branch2a")
        conv1 = self.conv_bn_layer(
            input=conv0, num_filters=num_filters, filter_size=3, act=None,
            name=name + "_branch2b")
        short = self.shortcut(input, num_filters, stride, is_first, name=name + "_branch1")
        return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
|
|
@ -32,7 +32,7 @@ class ResNet():
|
|||
def __init__(self, params):
|
||||
self.layers = params['layers']
|
||||
self.is_3x3 = True
|
||||
supported_layers = [18, 34, 50, 101, 152, 200]
|
||||
supported_layers = [18, 34, 50, 101, 152]
|
||||
assert self.layers in supported_layers, \
|
||||
"supported layers are {} but input layer is {}".format(supported_layers, self.layers)
|
||||
|
||||
|
|
|
@ -0,0 +1,218 @@
|
|||
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
#from .rec_seq_encoder import SequenceEncoder
|
||||
#from ..common_functions import get_para_bias_attr
|
||||
import numpy as np
|
||||
from .self_attention.model import wrap_encoder
|
||||
from .self_attention.model import wrap_encoder_forFeature
|
||||
gradient_clip = 10
|
||||
|
||||
|
||||
|
||||
class SRNPredict(object):
    """SRN recognition head: PVAM -> GSRM -> VSFD, built with fluid layers."""

    def __init__(self, params):
        """Copy the head hyper-parameters out of ``params``."""
        super(SRNPredict, self).__init__()
        self.char_num = params['char_num']
        self.max_length = params['max_text_length']

        self.num_heads = params['num_heads']
        self.num_encoder_TUs = params['num_encoder_TUs']
        self.num_decoder_TUs = params['num_decoder_TUs']
        self.hidden_dims = params['hidden_dims']

    def _encoder_settings(self):
        """Keyword arguments shared by every encoder built in this head."""
        head_dim = int(self.hidden_dims / self.num_heads)
        return dict(
            n_head=self.num_heads,
            d_key=head_dim,
            d_value=head_dim,
            d_model=self.hidden_dims,
            d_inner_hid=self.hidden_dims,
            prepostprocess_dropout=0.1,
            attention_dropout=0.1,
            relu_dropout=0.1,
            preprocess_cmd="n",
            postprocess_cmd="da",
            weight_sharing=True)

    def pvam(self, inputs, others):
        """Parallel Visual Attention Module.

        Flattens the 4-D ``inputs`` to a sequence, encodes it with
        ``wrap_encoder_forFeature`` and attends over it once per text
        position. ``others`` must provide ``"encoder_word_pos"`` and
        ``"gsrm_word_pos"``. Also installs a global value-based gradient clip.
        """
        _, channels, height, width = inputs.shape
        seq = fluid.layers.reshape(x=inputs, shape=[-1, channels, height * width])
        seq = fluid.layers.transpose(x=seq, perm=[0, 2, 1])

        # ===== Transformer encoder =====
        _, t, _ = seq.shape
        encoder_word_pos = others["encoder_word_pos"]
        gsrm_word_pos = others["gsrm_word_pos"]

        word_features = wrap_encoder_forFeature(
            src_vocab_size=-1,
            max_length=t,
            n_layer=self.num_encoder_TUs,
            enc_inputs=[seq, encoder_word_pos, None],
            **self._encoder_settings())
        fluid.clip.set_gradient_clip(fluid.clip.GradientClipByValue(gradient_clip))

        # ===== Parallel Visual Attention Module =====
        _, t, c = word_features.shape

        word_features = fluid.layers.fc(word_features, c, num_flatten_dims=2)
        tiled_features = fluid.layers.reshape(word_features, [-1, 1, t, c])
        tiled_features = fluid.layers.expand(tiled_features, [1, self.max_length, 1, 1])
        pos_embedding = fluid.layers.embedding(gsrm_word_pos, [self.max_length, c])
        tiled_pos = fluid.layers.reshape(pos_embedding, [-1, self.max_length, 1, c])
        tiled_pos = fluid.layers.expand(tiled_pos, [1, 1, t, 1])
        energy = fluid.layers.elementwise_add(tiled_features, tiled_pos, act='tanh')

        attention = fluid.layers.fc(input=energy, size=1, num_flatten_dims=3, bias_attr=False)
        attention = fluid.layers.reshape(x=attention, shape=[-1, self.max_length, t])
        attention = fluid.layers.softmax(input=attention, axis=-1)

        # [b, max_length, c]
        return fluid.layers.matmul(attention, word_features)

    def gsrm(self, pvam_features, others):
        """Global Semantic Reasoning Module.

        Returns ``(gsrm_features, word_out, gsrm_out)``. ``others`` must
        provide ``"gsrm_word_pos"``, ``"gsrm_slf_attn_bias1"`` and
        ``"gsrm_slf_attn_bias2"``.
        """
        # ===== GSRM Visual-to-semantic embedding block =====
        _, t, c = pvam_features.shape
        word_out = fluid.layers.fc(
            input=fluid.layers.reshape(pvam_features, [-1, c]),
            size=self.char_num,
            act="softmax")
        word_ids = fluid.layers.argmax(word_out, axis=1)
        word_ids.stop_gradient = True
        word_ids = fluid.layers.reshape(x=word_ids, shape=[-1, t, 1])

        # ===== GSRM Semantic reasoning block =====
        # This module is realised with two transformers: gsrm_feature1 is the
        # forward one, gsrm_feature2 is the backward one.
        pad_idx = self.char_num
        gsrm_word_pos = others["gsrm_word_pos"]
        forward_bias = others["gsrm_slf_attn_bias1"]
        backward_bias = others["gsrm_slf_attn_bias2"]

        # Forward input: prepend pad_idx along axis 1 and drop the last
        # position (shift right by one). Backward input: the ids themselves.
        forward_ids = fluid.layers.cast(word_ids, "float32")
        forward_ids = fluid.layers.pad(
            forward_ids, [0, 0, 1, 0, 0, 0], pad_value=1.0 * pad_idx)
        forward_ids = fluid.layers.cast(forward_ids, "int64")
        forward_ids = forward_ids[:, :-1, :]
        backward_ids = word_ids

        forward_ids.stop_gradient = True
        backward_ids.stop_gradient = True

        shared = self._encoder_settings()
        gsrm_feature1 = wrap_encoder(
            src_vocab_size=self.char_num + 1,
            max_length=self.max_length,
            n_layer=self.num_decoder_TUs,
            enc_inputs=[forward_ids, gsrm_word_pos, forward_bias],
            **shared)
        gsrm_feature2 = wrap_encoder(
            src_vocab_size=self.char_num + 1,
            max_length=self.max_length,
            n_layer=self.num_decoder_TUs,
            enc_inputs=[backward_ids, gsrm_word_pos, backward_bias],
            **shared)
        # Append a zero step along axis 1 and drop the first one (shift left).
        gsrm_feature2 = fluid.layers.pad(gsrm_feature2, [0, 0, 0, 1, 0, 0], pad_value=0.)
        gsrm_feature2 = gsrm_feature2[:, 1:]
        gsrm_features = gsrm_feature1 + gsrm_feature2

        _, _, _ = gsrm_features.shape  # the original also required rank 3 here

        # Project onto the variable named "src_word_emb_table" looked up in
        # the main program (presumably created by wrap_encoder -- confirm).
        gsrm_out = fluid.layers.matmul(
            x=gsrm_features,
            y=fluid.default_main_program().global_block().var("src_word_emb_table"),
            transpose_y=True)
        _, _, vocab = gsrm_out.shape
        gsrm_out = fluid.layers.softmax(input=fluid.layers.reshape(gsrm_out, [-1, vocab]))

        return gsrm_features, word_out, gsrm_out

    def vsfd(self, pvam_features, gsrm_features):
        """Visual-Semantic Fusion Decoder: gated mix of both feature sets."""
        _, _, c1 = pvam_features.shape
        _, t, c2 = gsrm_features.shape
        stacked = fluid.layers.concat([pvam_features, gsrm_features], axis=2)
        stacked = fluid.layers.reshape(x=stacked, shape=[-1, c1 + c2])
        gate = fluid.layers.fc(input=stacked, size=c1, act="sigmoid")
        gate = fluid.layers.reshape(x=gate, shape=[-1, t, c1])
        fused = gate * pvam_features + (1.0 - gate) * gsrm_features
        fused = fluid.layers.reshape(x=fused, shape=[-1, c1])

        return fluid.layers.fc(input=fused, size=self.char_num, act="softmax")

    def __call__(self, inputs, others, mode=None):
        """Run PVAM, GSRM and VSFD and return the prediction dictionary.

        ``mode`` is accepted but not used. The result has the keys
        ``'predict'``, ``'decoded_out'`` (top-1 index of ``predict``),
        ``'word_out'`` and ``'gsrm_out'``.
        """
        pvam_features = self.pvam(inputs, others)
        gsrm_features, word_out, gsrm_out = self.gsrm(pvam_features, others)
        final_out = self.vsfd(pvam_features, gsrm_features)

        _, decoded_out = fluid.layers.topk(input=final_out, k=1)
        return {
            'predict': final_out,
            'decoded_out': decoded_out,
            'word_out': word_out,
            'gsrm_out': gsrm_out,
        }
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,58 @@
|
|||
#copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
#Licensed under the Apache License, Version 2.0 (the "License");
|
||||
#you may not use this file except in compliance with the License.
|
||||
#You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
#Unless required by applicable law or agreed to in writing, software
|
||||
#distributed under the License is distributed on an "AS IS" BASIS,
|
||||
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
#See the License for the specific language governing permissions and
|
||||
#limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
|
||||
import paddle
|
||||
import paddle.fluid as fluid
|
||||
|
||||
|
||||
class SRNLoss(object):
    """Weighted sum of the three SRN cross-entropy terms."""

    def __init__(self, params):
        """Store ``params['char_num']``."""
        super(SRNLoss, self).__init__()
        self.char_num = params['char_num']

    def __call__(self, predicts, others):
        """Build the loss ops.

        Args:
            predicts: dict with ``'predict'``, ``'word_out'`` and
                ``'gsrm_out'``.
            others: dict with ``'label'`` and ``'lbl_weight'``; the weight
                is looked up but not applied.

        Returns:
            ``[sum_cost, cost_vsfd, cost_word]`` with
            ``sum_cost = cost_word + 2.0 * cost_vsfd + 0.15 * cost_gsrm``,
            each term being a summed cross entropy reshaped to ``[1]``.
        """
        vsfd_predict = predicts['predict']
        word_predict = predicts['word_out']
        gsrm_predict = predicts['gsrm_out']
        label = others['label']
        lbl_weight = others['lbl_weight']  # required key, currently unused

        target = fluid.layers.cast(x=label, dtype='int64')

        # Ops are created in the order word, gsrm, vsfd.
        per_token = [
            fluid.layers.cross_entropy(input=branch, label=target)
            for branch in (word_predict, gsrm_predict, vsfd_predict)
        ]
        cost_word, cost_gsrm, cost_vsfd = [
            fluid.layers.reshape(x=fluid.layers.reduce_sum(cost), shape=[1])
            for cost in per_token
        ]

        sum_cost = fluid.layers.sum([cost_word, cost_vsfd * 2.0, cost_gsrm * 0.15])
        return [sum_cost, cost_vsfd, cost_word]
|
|
@ -25,6 +25,7 @@ class CharacterOps(object):
|
|||
def __init__(self, config):
|
||||
self.character_type = config['character_type']
|
||||
self.loss_type = config['loss_type']
|
||||
self.max_text_len = config['max_text_length']
|
||||
if self.character_type == "en":
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
|
@ -54,6 +55,8 @@ class CharacterOps(object):
|
|||
self.end_str = "eos"
|
||||
if self.loss_type == "attention":
|
||||
dict_character = [self.beg_str, self.end_str] + dict_character
|
||||
elif self.loss_type == "srn":
|
||||
dict_character = dict_character + [self.beg_str, self.end_str]
|
||||
self.dict = {}
|
||||
for i, char in enumerate(dict_character):
|
||||
self.dict[char] = i
|
||||
|
@ -146,6 +149,48 @@ def cal_predicts_accuracy(char_ops,
|
|||
acc = acc_num * 1.0 / img_num
|
||||
return acc, acc_num, img_num
|
||||
|
||||
def cal_predicts_accuracy_srn(char_ops,
                              preds,
                              labels,
                              max_text_len,
                              is_debug=False):
    """Count exactly recognised samples in flattened SRN predictions.

    ``preds`` and ``labels`` are indexed as ``arr[row][0]`` and hold
    ``max_text_len`` consecutive rows per sample; the value 37 marks padding.
    A sample is correct when the prediction equals the label on every
    position before the label's first pad and then either the label fills
    all ``max_text_len`` positions or the next prediction is the pad value.

    Args:
        char_ops: unused.
        preds: predicted ids; ``preds.shape[0]`` gives the total row count.
        labels: ground-truth ids laid out like ``preds``.
        max_text_len: number of rows per sample.
        is_debug: when true, print the label and prediction of each sample.

    Returns:
        ``(acc, acc_num, img_num)``.
    """
    pad_idx = 37
    correct = 0
    sample_count = int(preds.shape[0] / max_text_len)

    def _until_pad(rows, start):
        # Values of ``rows`` from ``start`` up to the first pad (max one sample).
        taken = []
        step = 0
        while step < max_text_len and rows[start + step] != pad_idx:
            taken.append(rows[start + step][0])
            step += 1
        return taken

    for sample in range(sample_count):
        offset = sample * max_text_len
        target = _until_pad(labels, offset)

        if is_debug:
            guess = _until_pad(preds, offset)
            print("cur_label: ", target)
            print("cur_pred: ", guess)

        length = len(target)
        prefix_ok = all(
            preds[offset + k][0] == target[k] for k in range(length))
        if not prefix_ok:
            continue
        # The prediction must stop where the label stops.
        if length == max_text_len or preds[offset + length][0] == pad_idx:
            correct += 1

    acc = correct * 1.0 / sample_count
    return acc, correct, sample_count
|
||||
|
||||
|
||||
def convert_rec_attention_infer_res(preds):
|
||||
img_num = preds.shape[0]
|
||||
|
|
|
@ -29,7 +29,7 @@ FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
|
|||
logging.basicConfig(level=logging.INFO, format=FORMAT)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from ppocr.utils.character import cal_predicts_accuracy
|
||||
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn
|
||||
from ppocr.utils.character import convert_rec_label_to_lod
|
||||
from ppocr.utils.character import convert_rec_attention_infer_res
|
||||
from ppocr.utils.utility import create_module
|
||||
|
@ -60,19 +60,52 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
|
|||
for ino in range(img_num):
|
||||
img_list.append(data[ino][0])
|
||||
label_list.append(data[ino][1])
|
||||
img_list = np.concatenate(img_list, axis=0)
|
||||
outs = exe.run(eval_info_dict['program'], \
|
||||
|
||||
if config['Global']['loss_type'] != "srn":
|
||||
img_list = np.concatenate(img_list, axis=0)
|
||||
outs = exe.run(eval_info_dict['program'], \
|
||||
feed={'image': img_list}, \
|
||||
fetch_list=eval_info_dict['fetch_varname_list'], \
|
||||
return_numpy=False)
|
||||
preds = np.array(outs[0])
|
||||
if preds.shape[1] != 1:
|
||||
preds, preds_lod = convert_rec_attention_infer_res(preds)
|
||||
preds = np.array(outs[0])
|
||||
|
||||
if preds.shape[1] != 1:
|
||||
preds, preds_lod = convert_rec_attention_infer_res(preds)
|
||||
else:
|
||||
preds_lod = outs[0].lod()[0]
|
||||
labels, labels_lod = convert_rec_label_to_lod(label_list)
|
||||
acc, acc_num, sample_num = cal_predicts_accuracy(
|
||||
char_ops, preds, preds_lod, labels, labels_lod, is_remove_duplicate)
|
||||
else:
|
||||
preds_lod = outs[0].lod()[0]
|
||||
labels, labels_lod = convert_rec_label_to_lod(label_list)
|
||||
acc, acc_num, sample_num = cal_predicts_accuracy(
|
||||
char_ops, preds, preds_lod, labels, labels_lod, is_remove_duplicate)
|
||||
encoder_word_pos_list = []
|
||||
gsrm_word_pos_list = []
|
||||
gsrm_slf_attn_bias1_list = []
|
||||
gsrm_slf_attn_bias2_list = []
|
||||
for ino in range(img_num):
|
||||
encoder_word_pos_list.append(data[ino][2])
|
||||
gsrm_word_pos_list.append(data[ino][3])
|
||||
gsrm_slf_attn_bias1_list.append(data[ino][4])
|
||||
gsrm_slf_attn_bias2_list.append(data[ino][5])
|
||||
|
||||
img_list = np.concatenate(img_list, axis=0)
|
||||
label_list = np.concatenate(label_list, axis=0)
|
||||
encoder_word_pos_list = np.concatenate(encoder_word_pos_list, axis=0).astype(np.int64)
|
||||
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list, axis=0).astype(np.int64)
|
||||
gsrm_slf_attn_bias1_list = np.concatenate(gsrm_slf_attn_bias1_list, axis=0).astype(np.float32)
|
||||
gsrm_slf_attn_bias2_list = np.concatenate(gsrm_slf_attn_bias2_list, axis=0).astype(np.float32)
|
||||
|
||||
labels = label_list
|
||||
|
||||
outs = exe.run(eval_info_dict['program'], \
|
||||
feed={'image': img_list, 'encoder_word_pos': encoder_word_pos_list,
|
||||
'gsrm_word_pos': gsrm_word_pos_list, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1_list,
|
||||
'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2_list}, \
|
||||
fetch_list=eval_info_dict['fetch_varname_list'], \
|
||||
return_numpy=False)
|
||||
preds = np.array(outs[0])
|
||||
acc, acc_num, sample_num = cal_predicts_accuracy_srn(
|
||||
char_ops, preds, labels, config['Global']['max_text_length'])
|
||||
|
||||
total_acc_num += acc_num
|
||||
total_sample_num += sample_num
|
||||
logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
|
||||
|
@ -85,8 +118,8 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
|
|||
|
||||
def test_rec_benchmark(exe, config, eval_info_dict):
|
||||
" Evaluate lmdb dataset "
|
||||
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', \
|
||||
'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
|
||||
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', \
|
||||
'IC13_857', 'IC15_1811', 'IC15_2077','SVTP', 'CUTE80']
|
||||
eval_data_dir = config['TestReader']['lmdb_sets_dir']
|
||||
total_evaluation_data_number = 0
|
||||
total_correct_number = 0
|
||||
|
|
|
@ -32,7 +32,7 @@ from eval_utils.eval_det_utils import eval_det_run
|
|||
from eval_utils.eval_rec_utils import eval_rec_run
|
||||
from ppocr.utils.save_load import save_model
|
||||
import numpy as np
|
||||
from ppocr.utils.character import cal_predicts_accuracy, CharacterOps
|
||||
from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps
|
||||
|
||||
class ArgsParser(ArgumentParser):
|
||||
def __init__(self):
|
||||
|
@ -176,8 +176,16 @@ def build(config, main_prog, startup_prog, mode):
|
|||
fetch_name_list = list(outputs.keys())
|
||||
fetch_varname_list = [outputs[v].name for v in fetch_name_list]
|
||||
opt_loss_name = None
|
||||
model_average = None
|
||||
img_loss_name = None
|
||||
word_loss_name = None
|
||||
if mode == "train":
|
||||
opt_loss = outputs['total_loss']
|
||||
# srn loss
|
||||
#img_loss = outputs['img_loss']
|
||||
#word_loss = outputs['word_loss']
|
||||
#img_loss_name = img_loss.name
|
||||
#word_loss_name = word_loss.name
|
||||
opt_params = config['Optimizer']
|
||||
optimizer = create_module(opt_params['function'])(opt_params)
|
||||
optimizer.minimize(opt_loss)
|
||||
|
@ -185,7 +193,13 @@ def build(config, main_prog, startup_prog, mode):
|
|||
global_lr = optimizer._global_learning_rate()
|
||||
fetch_name_list.insert(0, "lr")
|
||||
fetch_varname_list.insert(0, global_lr.name)
|
||||
return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name)
|
||||
if config['Global']["loss_type"] == 'srn':
|
||||
model_average = fluid.optimizer.ModelAverage(
|
||||
config['Global']['average_window'],
|
||||
min_average_window=config['Global']['min_average_window'],
|
||||
max_average_window=config['Global']['max_average_window'])
|
||||
|
||||
return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name,model_average)
|
||||
|
||||
|
||||
def build_export(config, main_prog, startup_prog):
|
||||
|
@ -329,14 +343,20 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
|
|||
lr = np.mean(np.array(train_outs[fetch_map['lr']]))
|
||||
preds_idx = fetch_map['decoded_out']
|
||||
preds = np.array(train_outs[preds_idx])
|
||||
preds_lod = train_outs[preds_idx].lod()[0]
|
||||
labels_idx = fetch_map['label']
|
||||
labels = np.array(train_outs[labels_idx])
|
||||
labels_lod = train_outs[labels_idx].lod()[0]
|
||||
|
||||
acc, acc_num, img_num = cal_predicts_accuracy(
|
||||
config['Global']['char_ops'], preds, preds_lod, labels,
|
||||
labels_lod)
|
||||
if config['Global']['loss_type'] != 'srn':
|
||||
preds_lod = train_outs[preds_idx].lod()[0]
|
||||
labels_lod = train_outs[labels_idx].lod()[0]
|
||||
|
||||
acc, acc_num, img_num = cal_predicts_accuracy(
|
||||
config['Global']['char_ops'], preds, preds_lod, labels,
|
||||
labels_lod)
|
||||
else:
|
||||
acc, acc_num, img_num = cal_predicts_accuracy_srn(
|
||||
config['Global']['char_ops'], preds, labels,
|
||||
config['Global']['max_text_length'])
|
||||
t2 = time.time()
|
||||
train_batch_elapse = t2 - t1
|
||||
stats = {'loss': loss, 'acc': acc}
|
||||
|
@ -350,6 +370,9 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
|
|||
|
||||
if train_batch_id > 0 and\
|
||||
train_batch_id % eval_batch_step == 0:
|
||||
model_average = train_info_dict['model_average']
|
||||
if model_average != None:
|
||||
model_average.apply(exe)
|
||||
metrics = eval_rec_run(exe, config, eval_info_dict, "eval")
|
||||
eval_acc = metrics['avg_acc']
|
||||
eval_sample_num = metrics['total_sample_num']
|
||||
|
|
|
@ -52,6 +52,7 @@ def main():
|
|||
train_fetch_name_list = train_build_outputs[1]
|
||||
train_fetch_varname_list = train_build_outputs[2]
|
||||
train_opt_loss_name = train_build_outputs[3]
|
||||
model_average = train_build_outputs[-1]
|
||||
|
||||
eval_program = fluid.Program()
|
||||
eval_build_outputs = program.build(
|
||||
|
@ -85,7 +86,8 @@ def main():
|
|||
'train_program':train_program,\
|
||||
'reader':train_loader,\
|
||||
'fetch_name_list':train_fetch_name_list,\
|
||||
'fetch_varname_list':train_fetch_varname_list}
|
||||
'fetch_varname_list':train_fetch_varname_list,\
|
||||
'model_average': model_average}
|
||||
|
||||
eval_info_dict = {'program':eval_program,\
|
||||
'reader':eval_reader,\
|
||||
|
|
Loading…
Reference in New Issue