update config
This commit is contained in:
parent 09d8cb6d98
commit 6832ca029f
@@ -17,10 +17,11 @@ Global:
   average_window: 0.15
   max_average_window: 15625
   min_average_window: 10000
-  reader_yml: ./configs/rec/rec_srn_reader.yml
+  reader_yml: ./configs/rec/rec_benchmark_reader.yml
   pretrain_weights:
   checkpoints:
   save_inference_dir:
   infer_img:
 
 Architecture:
   function: ppocr.modeling.architectures.rec_model,RecModel
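
Note: a minimal sketch of how the Global keys touched above are typically consumed. The config path is only an example; PaddleOCR normally loads this file through its own ArgsParser/load_config helpers, and the average_window values feed the ModelAverage optimizer set up further down in this commit.

    import yaml

    # Example path; substitute the SRN config actually used in this repo.
    with open("configs/rec/rec_r50fpn_vd_none_srn.yml") as f:
        cfg = yaml.safe_load(f)

    glb = cfg["Global"]
    print(glb["reader_yml"])           # ./configs/rec/rec_benchmark_reader.yml after this change
    print(glb["average_window"],       # 0.15
          glb["min_average_window"],   # 10000
          glb["max_average_window"])   # 15625
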
@@ -118,15 +118,14 @@ class LMDBReader(object):
                 image_file_list = get_image_file_list(self.infer_img)
                 for single_img in image_file_list:
                     img = cv2.imread(single_img)
-                    if img.shape[-1]==1 or len(list(img.shape))==2:
+                    if img.shape[-1] == 1 or len(list(img.shape)) == 2:
                         img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
                     if self.loss_type == 'srn':
                         norm_img = process_image_srn(
                             img=img,
                             image_shape=self.image_shape,
                             num_heads=self.num_heads,
-                            max_text_length=self.max_text_length
-                            )
+                            max_text_length=self.max_text_length)
                     else:
                         norm_img = process_image(
                             img=img,
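
Note: the channel check above can be exercised on its own; a small sketch follows (the image path is illustrative, and cv2.IMREAD_UNCHANGED is used so a grayscale file actually arrives as a 2-D array).

    import cv2

    img = cv2.imread("word_img.png", cv2.IMREAD_UNCHANGED)  # illustrative path
    if img.shape[-1] == 1 or len(list(img.shape)) == 2:
        # Expand single-channel input to 3-channel BGR so the normalization
        # code downstream always sees an H x W x 3 array.
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
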
@@ -135,20 +134,20 @@ class LMDBReader(object):
                             tps=self.use_tps,
                             infer_mode=True)
                     yield norm_img
-            elif self.mode == 'test':
-                image_file_list = get_image_file_list(self.infer_img)
-                for single_img in image_file_list:
-                    img = cv2.imread(single_img)
-                    if img.shape[-1]==1 or len(list(img.shape))==2:
-                        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
-                    norm_img = process_image(
-                        img=img,
-                        image_shape=self.image_shape,
-                        char_ops=self.char_ops,
-                        tps=self.use_tps,
-                        infer_mode=True
-                        )
-                    yield norm_img
+            #elif self.mode == 'eval':
+            #    image_file_list = get_image_file_list(self.infer_img)
+            #    for single_img in image_file_list:
+            #        img = cv2.imread(single_img)
+            #        if img.shape[-1]==1 or len(list(img.shape))==2:
+            #            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+            #        norm_img = process_image(
+            #            img=img,
+            #            image_shape=self.image_shape,
+            #            char_ops=self.char_ops,
+            #            tps=self.use_tps,
+            #            infer_mode=True
+            #            )
+            #        yield norm_img
             else:
                 lmdb_sets = self.load_hierarchical_lmdb_dataset()
                 if process_id == 0:
@@ -169,14 +168,15 @@ class LMDBReader(object):
                     img, label = sample_info
                     outs = []
                     if self.loss_type == "srn":
-                        outs = process_image_srn(img, self.image_shape, self.num_heads,
-                                                 self.max_text_length, label,
-                                                 self.char_ops, self.loss_type)
+                        outs = process_image_srn(
+                            img, self.image_shape, self.num_heads,
+                            self.max_text_length, label, self.char_ops,
+                            self.loss_type)
 
                     else:
-                        outs = process_image(img, self.image_shape, label,
-                                             self.char_ops, self.loss_type,
-                                             self.max_text_length)
+                        outs = process_image(
+                            img, self.image_shape, label, self.char_ops,
+                            self.loss_type, self.max_text_length)
                     if outs is None:
                         continue
                     yield outs
@@ -184,6 +184,7 @@ class LMDBReader(object):
                     if finish_read_num == len(lmdb_sets):
                         break
                 self.close_lmdb_dataset(lmdb_sets)
+
         def batch_iter_reader():
             batch_outs = []
             for outs in sample_iter_reader():
@@ -79,17 +79,45 @@ class RecModel(object):
                 feed_list = [image, label_in, label_out]
                 labels = {'label_in': label_in, 'label_out': label_out}
             elif self.loss_type == "srn":
-                encoder_word_pos = fluid.data(name="encoder_word_pos", shape=[-1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)), 1], dtype="int64")
-                gsrm_word_pos = fluid.data(name="gsrm_word_pos", shape=[-1, self.max_text_length, 1], dtype="int64")
-                gsrm_slf_attn_bias1 = fluid.data(name="gsrm_slf_attn_bias1", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length])
-                gsrm_slf_attn_bias2 = fluid.data(name="gsrm_slf_attn_bias2", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length])
-                lbl_weight = fluid.layers.data(name="lbl_weight", shape=[-1, 1], dtype='int64')
+                encoder_word_pos = fluid.data(
+                    name="encoder_word_pos",
+                    shape=[
+                        -1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)),
+                        1
+                    ],
+                    dtype="int64")
+                gsrm_word_pos = fluid.data(
+                    name="gsrm_word_pos",
+                    shape=[-1, self.max_text_length, 1],
+                    dtype="int64")
+                gsrm_slf_attn_bias1 = fluid.data(
+                    name="gsrm_slf_attn_bias1",
+                    shape=[
+                        -1, self.num_heads, self.max_text_length,
+                        self.max_text_length
+                    ])
+                gsrm_slf_attn_bias2 = fluid.data(
+                    name="gsrm_slf_attn_bias2",
+                    shape=[
+                        -1, self.num_heads, self.max_text_length,
+                        self.max_text_length
+                    ])
+                lbl_weight = fluid.layers.data(
+                    name="lbl_weight", shape=[-1, 1], dtype='int64')
                 label = fluid.data(
                     name='label', shape=[-1, 1], dtype='int32', lod_level=1)
-                feed_list = [image, label, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2, lbl_weight]
-                labels = {'label': label, 'encoder_word_pos': encoder_word_pos,
-                          'gsrm_word_pos': gsrm_word_pos, 'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
-                          'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2,'lbl_weight':lbl_weight}
+                feed_list = [
+                    image, label, encoder_word_pos, gsrm_word_pos,
+                    gsrm_slf_attn_bias1, gsrm_slf_attn_bias2, lbl_weight
+                ]
+                labels = {
+                    'label': label,
+                    'encoder_word_pos': encoder_word_pos,
+                    'gsrm_word_pos': gsrm_word_pos,
+                    'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
+                    'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2,
+                    'lbl_weight': lbl_weight
+                }
             else:
                 label = fluid.data(
                     name='label', shape=[None, 1], dtype='int32', lod_level=1)
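
Note: the fluid.data placeholders declared above expect position-id and attention-bias arrays from the data side. Below is a rough numpy sketch of inputs with matching shapes; the helper name and the triangular-mask value scheme are assumptions for illustration, not code from this commit.

    import numpy as np

    def make_srn_inputs(image_shape, num_heads, max_text_length):
        c, h, w = image_shape
        feature_len = int((h / 8) * (w / 8))
        # Position ids for the visual features and for the GSRM word slots.
        encoder_word_pos = np.arange(
            feature_len, dtype='int64').reshape((feature_len, 1))
        gsrm_word_pos = np.arange(
            max_text_length, dtype='int64').reshape((max_text_length, 1))
        # Two triangular self-attention biases: one masks "future" positions,
        # the other masks "past" positions (assumed scheme).
        ones = np.ones((max_text_length, max_text_length))
        bias1 = np.triu(ones, k=1)[None, None, :, :] * -1e9
        bias2 = np.tril(ones, k=-1)[None, None, :, :] * -1e9
        gsrm_slf_attn_bias1 = np.tile(bias1, (1, num_heads, 1, 1)).astype('float32')
        gsrm_slf_attn_bias2 = np.tile(bias2, (1, num_heads, 1, 1)).astype('float32')
        return (encoder_word_pos, gsrm_word_pos,
                gsrm_slf_attn_bias1, gsrm_slf_attn_bias2)

    # e.g. an 8-head GSRM with max_text_length 25 and a 1x64x256 input image
    inputs = make_srn_inputs([1, 64, 256], num_heads=8, max_text_length=25)
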
@@ -112,15 +140,41 @@ class RecModel(object):
                         "We set img_shape to be the same , it may affect the inference effect"
                     )
                     image_shape = deepcopy(self.image_shape)
-                image = fluid.data(name='image', shape=image_shape, dtype='float32')
+            image = fluid.data(name='image', shape=image_shape, dtype='float32')
             if self.loss_type == "srn":
-                encoder_word_pos = fluid.data(name="encoder_word_pos", shape=[-1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)), 1], dtype="int64")
-                gsrm_word_pos = fluid.data(name="gsrm_word_pos", shape=[-1, self.max_text_length, 1], dtype="int64")
-                gsrm_slf_attn_bias1 = fluid.data(name="gsrm_slf_attn_bias1", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length])
-                gsrm_slf_attn_bias2 = fluid.data(name="gsrm_slf_attn_bias2", shape=[-1, self.num_heads, self.max_text_length, self.max_text_length])
-                feed_list = [image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2]
-                labels = {'encoder_word_pos': encoder_word_pos, 'gsrm_word_pos': gsrm_word_pos,
-                          'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1, 'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2}
+                encoder_word_pos = fluid.data(
+                    name="encoder_word_pos",
+                    shape=[
+                        -1, int((image_shape[-2] / 8) * (image_shape[-1] / 8)),
+                        1
+                    ],
+                    dtype="int64")
+                gsrm_word_pos = fluid.data(
+                    name="gsrm_word_pos",
+                    shape=[-1, self.max_text_length, 1],
+                    dtype="int64")
+                gsrm_slf_attn_bias1 = fluid.data(
+                    name="gsrm_slf_attn_bias1",
+                    shape=[
+                        -1, self.num_heads, self.max_text_length,
+                        self.max_text_length
+                    ])
+                gsrm_slf_attn_bias2 = fluid.data(
+                    name="gsrm_slf_attn_bias2",
+                    shape=[
+                        -1, self.num_heads, self.max_text_length,
+                        self.max_text_length
+                    ])
+                feed_list = [
+                    image, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+                    gsrm_slf_attn_bias2
+                ]
+                labels = {
+                    'encoder_word_pos': encoder_word_pos,
+                    'gsrm_word_pos': gsrm_word_pos,
+                    'gsrm_slf_attn_bias1': gsrm_slf_attn_bias1,
+                    'gsrm_slf_attn_bias2': gsrm_slf_attn_bias2
+                }
         return image, labels, loader
 
     def __call__(self, mode):
@@ -140,8 +194,13 @@ class RecModel(object):
                 label = labels['label']
             if self.loss_type == 'srn':
                 total_loss, img_loss, word_loss = self.loss(predicts, labels)
-                outputs = {'total_loss':total_loss, 'img_loss':img_loss, 'word_loss':word_loss,
-                           'decoded_out':decoded_out, 'label':label}
+                outputs = {
+                    'total_loss': total_loss,
+                    'img_loss': img_loss,
+                    'word_loss': word_loss,
+                    'decoded_out': decoded_out,
+                    'label': label
+                }
             else:
                 outputs = {'total_loss':loss, 'decoded_out':\
                     decoded_out, 'label':label}
@@ -4,8 +4,9 @@ import numpy as np
 import paddle.fluid as fluid
 import paddle.fluid.layers as layers
 
+from .desc import *
+from .config import ModelHyperParams,TrainTaskConfig
 # Set seed for CE
 dropout_seed = None
 
 
 def wrap_layer_with_block(layer, block_idx):
     """
@@ -269,23 +270,24 @@ pre_process_layer = partial(pre_post_process_layer, None)
 post_process_layer = pre_post_process_layer
 
 
-def prepare_encoder(src_word,#[b,t,c]
-                    src_pos,
-                    src_vocab_size,
-                    src_emb_dim,
-                    src_max_len,
-                    dropout_rate=0.,
-                    bos_idx=0,
-                    word_emb_param_name=None,
-                    pos_enc_param_name=None):
+def prepare_encoder(
+        src_word, #[b,t,c]
+        src_pos,
+        src_vocab_size,
+        src_emb_dim,
+        src_max_len,
+        dropout_rate=0.,
+        bos_idx=0,
+        word_emb_param_name=None,
+        pos_enc_param_name=None):
     """Add word embeddings and position encodings.
     The output tensor has a shape of:
     [batch_size, max_src_length_in_batch, d_model].
     This module is used at the bottom of the encoder stacks.
     """
 
-    src_word_emb =src_word#layers.concat(res,axis=1)
-    src_word_emb=layers.cast(src_word_emb,'float32')
+    src_word_emb = src_word #layers.concat(res,axis=1)
+    src_word_emb = layers.cast(src_word_emb, 'float32')
     # print("src_word_emb",src_word_emb)
 
     src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
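
Note: prepare_encoder above treats the conv features themselves as the "word embeddings": cast to float32, then scale by sqrt(d_model), presumably followed by adding the position table as in the original Transformer helper. A toy numpy sketch of that arithmetic (all shapes and values are made up):

    import numpy as np

    b, t, d_model = 2, 25, 512                      # batch, steps, src_emb_dim
    conv_feats = np.random.rand(b, t, d_model)      # plays the role of src_word
    pos_table = np.random.rand(t, d_model)          # stand-in for the position table

    src_word_emb = conv_feats.astype('float32') * d_model**0.5
    enc_input = src_word_emb + pos_table            # broadcast over the batch dim
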
@@ -302,14 +304,14 @@ def prepare_encoder(src_word,#[b,t,c]
 
 
 def prepare_decoder(src_word,
-        src_pos,
-        src_vocab_size,
-        src_emb_dim,
-        src_max_len,
-        dropout_rate=0.,
-        bos_idx=0,
-        word_emb_param_name=None,
-        pos_enc_param_name=None):
+                    src_pos,
+                    src_vocab_size,
+                    src_emb_dim,
+                    src_max_len,
+                    dropout_rate=0.,
+                    bos_idx=0,
+                    word_emb_param_name=None,
+                    pos_enc_param_name=None):
     """Add word embeddings and position encodings.
     The output tensor has a shape of:
     [batch_size, max_src_length_in_batch, d_model].
@@ -323,7 +325,7 @@ def prepare_decoder(src_word,
             name=word_emb_param_name,
             initializer=fluid.initializer.Normal(0., src_emb_dim**-0.5)))
     # print("target_word_emb",src_word_emb)
-    src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim ** 0.5)
+    src_word_emb = layers.scale(x=src_word_emb, scale=src_emb_dim**0.5)
     src_pos_enc = layers.embedding(
         src_pos,
         size=[src_max_len, src_emb_dim],
@@ -335,6 +337,7 @@ def prepare_decoder(src_word,
         enc_input, dropout_prob=dropout_rate, seed=dropout_seed,
         is_test=False) if dropout_rate else enc_input
 
+
 # prepare_encoder = partial(
 #     prepare_encoder_decoder, pos_enc_param_name=pos_enc_param_names[0])
 # prepare_decoder = partial(
@@ -595,21 +598,9 @@ def transformer(src_vocab_size,
     weights = all_inputs[-1]
 
     enc_output = wrap_encoder(
-        src_vocab_size,
-        ModelHyperParams.src_seq_len,
-        n_layer,
-        n_head,
-        d_key,
-        d_value,
-        d_model,
-        d_inner_hid,
-        prepostprocess_dropout,
-        attention_dropout,
-        relu_dropout,
-        preprocess_cmd,
-        postprocess_cmd,
-        weight_sharing,
-        enc_inputs)
+        src_vocab_size, 64, n_layer, n_head, d_key, d_value, d_model,
+        d_inner_hid, prepostprocess_dropout, attention_dropout, relu_dropout,
+        preprocess_cmd, postprocess_cmd, weight_sharing, enc_inputs)
 
     predict = wrap_decoder(
         trg_vocab_size,
@@ -650,21 +641,21 @@ def transformer(src_vocab_size,
 
 
 def wrap_encoder_forFeature(src_vocab_size,
-        max_length,
-        n_layer,
-        n_head,
-        d_key,
-        d_value,
-        d_model,
-        d_inner_hid,
-        prepostprocess_dropout,
-        attention_dropout,
-        relu_dropout,
-        preprocess_cmd,
-        postprocess_cmd,
-        weight_sharing,
-        enc_inputs=None,
-        bos_idx=0):
+                            max_length,
+                            n_layer,
+                            n_head,
+                            d_key,
+                            d_value,
+                            d_model,
+                            d_inner_hid,
+                            prepostprocess_dropout,
+                            attention_dropout,
+                            relu_dropout,
+                            preprocess_cmd,
+                            postprocess_cmd,
+                            weight_sharing,
+                            enc_inputs=None,
+                            bos_idx=0):
     """
     The wrapper assembles together all needed layers for the encoder.
     img, src_pos, src_slf_attn_bias = enc_inputs
@@ -676,8 +667,8 @@ def wrap_encoder_forFeature(src_vocab_size,
         conv_features, src_pos, src_slf_attn_bias = make_all_inputs(
             encoder_data_input_fields)
     else:
-        conv_features, src_pos, src_slf_attn_bias = enc_inputs#
-        b,t,c = conv_features.shape
+        conv_features, src_pos, src_slf_attn_bias = enc_inputs #
+        b, t, c = conv_features.shape
     #"""
     # insert cnn
     #"""
@@ -718,7 +709,7 @@ def wrap_encoder_forFeature(src_vocab_size,
         max_length,
         prepostprocess_dropout,
         bos_idx=bos_idx,
-        word_emb_param_name=word_emb_param_names[0])
+        word_emb_param_name="src_word_emb_table")
 
     enc_output = encoder(
         enc_input,
@@ -736,6 +727,7 @@ def wrap_encoder_forFeature(src_vocab_size,
         postprocess_cmd, )
     return enc_output
 
+
 def wrap_encoder(src_vocab_size,
                  max_length,
                  n_layer,
@@ -762,7 +754,7 @@ def wrap_encoder(src_vocab_size,
         src_word, src_pos, src_slf_attn_bias = make_all_inputs(
             encoder_data_input_fields)
     else:
-        src_word, src_pos, src_slf_attn_bias = enc_inputs#
+        src_word, src_pos, src_slf_attn_bias = enc_inputs #
     #"""
     # insert cnn
     #"""
@@ -802,7 +794,7 @@ def wrap_encoder(src_vocab_size,
         max_length,
         prepostprocess_dropout,
         bos_idx=bos_idx,
-        word_emb_param_name=word_emb_param_names[0])
+        word_emb_param_name="src_word_emb_table")
 
     enc_output = encoder(
         enc_input,
@@ -858,8 +850,8 @@ def wrap_decoder(trg_vocab_size,
         max_length,
         prepostprocess_dropout,
         bos_idx=bos_idx,
-        word_emb_param_name=word_emb_param_names[0]
-        if weight_sharing else word_emb_param_names[1])
+        word_emb_param_name="src_word_emb_table"
+        if weight_sharing else "trg_word_emb_table")
     dec_output = decoder(
         dec_input,
         enc_output,
@@ -886,7 +878,7 @@ def wrap_decoder(trg_vocab_size,
         predict = layers.matmul(
             x=dec_output,
             y=fluid.default_main_program().global_block().var(
-                word_emb_param_names[0]),
+                "trg_word_emb_table"),
             transpose_y=True)
     else:
         predict = layers.fc(input=dec_output,
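
Note: the hard-coded names above ("src_word_emb_table" / "trg_word_emb_table") replace the old word_emb_param_names lookups and relate to weight sharing: with weight_sharing on, the decoder's output projection reuses an embedding table via matmul(..., transpose_y=True). A toy numpy version of that tied projection (sizes are made up):

    import numpy as np

    vocab_size, d_model, seq_len = 38, 512, 25      # toy sizes
    emb_table = np.random.rand(vocab_size, d_model).astype('float32')
    dec_output = np.random.rand(seq_len, d_model).astype('float32')

    # Reusing the embedding table as the output projection (weight tying).
    logits = dec_output @ emb_table.T               # shape: (seq_len, vocab_size)
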
@@ -931,12 +923,13 @@ def fast_decode(src_vocab_size,
 
     enc_inputs_len = len(encoder_data_input_fields)
     dec_inputs_len = len(fast_decoder_data_input_fields)
-    enc_inputs = all_inputs[0:enc_inputs_len]#enc_inputs tensor
-    dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len + dec_inputs_len]#dec_inputs tensor
+    enc_inputs = all_inputs[0:enc_inputs_len] #enc_inputs tensor
+    dec_inputs = all_inputs[enc_inputs_len:enc_inputs_len +
+                            dec_inputs_len] #dec_inputs tensor
 
     enc_output = wrap_encoder(
         src_vocab_size,
-        ModelHyperParams.src_seq_len,##to do !!!!!????
+        64, ##to do !!!!!????
         n_layer,
         n_head,
         d_key,
@@ -75,7 +75,8 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
             preds_lod = outs[0].lod()[0]
             labels, labels_lod = convert_rec_label_to_lod(label_list)
             acc, acc_num, sample_num = cal_predicts_accuracy(
-                char_ops, preds, preds_lod, labels, labels_lod, is_remove_duplicate)
+                char_ops, preds, preds_lod, labels, labels_lod,
+                is_remove_duplicate)
         else:
             encoder_word_pos_list = []
             gsrm_word_pos_list = []
@@ -89,10 +90,14 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
 
             img_list = np.concatenate(img_list, axis=0)
             label_list = np.concatenate(label_list, axis=0)
-            encoder_word_pos_list = np.concatenate(encoder_word_pos_list, axis=0).astype(np.int64)
-            gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list, axis=0).astype(np.int64)
-            gsrm_slf_attn_bias1_list = np.concatenate(gsrm_slf_attn_bias1_list, axis=0).astype(np.float32)
-            gsrm_slf_attn_bias2_list = np.concatenate(gsrm_slf_attn_bias2_list, axis=0).astype(np.float32)
+            encoder_word_pos_list = np.concatenate(
+                encoder_word_pos_list, axis=0).astype(np.int64)
+            gsrm_word_pos_list = np.concatenate(
+                gsrm_word_pos_list, axis=0).astype(np.int64)
+            gsrm_slf_attn_bias1_list = np.concatenate(
+                gsrm_slf_attn_bias1_list, axis=0).astype(np.float32)
+            gsrm_slf_attn_bias2_list = np.concatenate(
+                gsrm_slf_attn_bias2_list, axis=0).astype(np.float32)
 
             labels = label_list
 
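
Note: the np.concatenate calls above just stack the per-sample SRN inputs along the batch axis before feeding the eval program; a tiny sketch with dummy arrays (names and shapes are illustrative only):

    import numpy as np

    # Four dummy samples, each with a (1, 256, 1) encoder_word_pos array.
    encoder_word_pos_list = [np.zeros((1, 256, 1), dtype='int64') for _ in range(4)]
    batch = np.concatenate(encoder_word_pos_list, axis=0).astype(np.int64)
    print(batch.shape)   # (4, 256, 1)
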
@@ -108,7 +113,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
 
         total_acc_num += acc_num
         total_sample_num += sample_num
-        logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
+        #logger.info("eval batch id: {}, acc: {}".format(total_batch_num, acc))
         total_batch_num += 1
     avg_acc = total_acc_num * 1.0 / total_sample_num
     metrics = {'avg_acc': avg_acc, "total_acc_num": total_acc_num, \
@@ -34,6 +34,7 @@ from ppocr.utils.save_load import save_model
 import numpy as np
+from ppocr.utils.character import cal_predicts_accuracy, cal_predicts_accuracy_srn, CharacterOps
 
 
 class ArgsParser(ArgumentParser):
     def __init__(self):
         super(ArgsParser, self).__init__(
@@ -196,10 +197,13 @@ def build(config, main_prog, startup_prog, mode):
         if config['Global']["loss_type"] == 'srn':
             model_average = fluid.optimizer.ModelAverage(
                 config['Global']['average_window'],
-                min_average_window=config['Global']['min_average_window'],
-                max_average_window=config['Global']['max_average_window'])
+                min_average_window=config['Global'][
+                    'min_average_window'],
+                max_average_window=config['Global'][
+                    'max_average_window'])
 
-    return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name,model_average)
+    return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name,
+            model_average)
 
 
 def build_export(config, main_prog, startup_prog):
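
Note: a hedged sketch of how a ModelAverage handle like the one returned by build() is typically used with the fluid 1.x API shown in this diff: parameters are swapped for their running averages only inside the apply() context (e.g. around evaluation) and restored on exit. The surrounding executor/program setup is assumed, not taken from this commit.

    import paddle.fluid as fluid

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Same values as the Global config keys above.
    model_average = fluid.optimizer.ModelAverage(
        0.15,                        # average_window
        min_average_window=10000,
        max_average_window=15625)

    # ... run training iterations with exe.run(...) here ...

    with model_average.apply(exe):
        pass  # run the eval program here; original weights are restored afterwards
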
@@ -398,6 +402,7 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
         save_model(train_info_dict['train_program'], save_path)
     return
 
+
 def preprocess():
     FLAGS = ArgsParser().parse_args()
     config = load_config(FLAGS.config)
@@ -409,8 +414,8 @@ def preprocess():
     check_gpu(use_gpu)
 
     alg = config['Global']['algorithm']
-    assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE']
-    if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']:
+    assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']
+    if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN']:
         config['Global']['char_ops'] = CharacterOps(config['Global'])
 
     place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
@@ -0,0 +1 @@
+/workspace/PaddleOCR/train_data/