Merge pull request #848 from tink2123/add_anno

add comments for rec
This commit is contained in:
dyning 2020-09-27 11:27:19 +08:00 committed by GitHub
commit 2b16323080
7 changed files with 103 additions and 24 deletions

View File

@ -25,6 +25,12 @@ from copy import deepcopy
class RecModel(object):
"""
Rec model architecture
Args:
params(object): Params from the yaml file and settings from the command line
"""
def __init__(self, params):
super(RecModel, self).__init__()
global_params = params['Global']
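For context, a minimal sketch of how such a params dict could be assembled from a YAML file plus command-line overrides; the helper name, file path, and override syntax are illustrative assumptions, not this repo's actual CLI.

import yaml

def load_params(yaml_path, overrides):
    # parse the yaml config into a nested dict, e.g. {'Global': {...}, 'Backbone': {...}}
    with open(yaml_path, "r") as f:
        params = yaml.safe_load(f)
    # apply "Section.key=value" overrides from the command line on top of the yaml values
    for item in overrides:
        key, value = item.split("=", 1)
        section, name = key.split(".", 1)
        params.setdefault(section, {})[name] = yaml.safe_load(value)
    return params

# params = load_params("configs/rec/my_rec_config.yml", ["Global.epoch_num=100"])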
@ -64,6 +70,12 @@ class RecModel(object):
self.num_heads = None
def create_feed(self, mode):
"""
Create feed dict and DataLoader object
Args:
mode(str): runtime mode, can be "train", "eval" or "test"
Return: image, labels, loader
"""
image_shape = deepcopy(self.image_shape)
image_shape.insert(0, -1)
if mode == "train":
@ -189,9 +201,12 @@ class RecModel(object):
inputs = image
else:
inputs = self.tps(image)
# backbone
conv_feas = self.backbone(inputs)
# predict
predicts = self.head(conv_feas, labels, mode)
decoded_out = predicts['decoded_out']
# loss
if mode == "train":
loss = self.loss(predicts, labels)
if self.loss_type == "attention":
@ -211,7 +226,7 @@ class RecModel(object):
outputs = {'total_loss':loss, 'decoded_out':\
decoded_out, 'label':label}
return loader, outputs
# export_model
elif mode == "export":
predict = predicts['predict']
if self.loss_type == "ctc":
@ -225,6 +240,7 @@ class RecModel(object):
]
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
# eval or test
else:
predict = predicts['predict']
if self.loss_type == "ctc":

View File

@ -27,6 +27,12 @@ import numpy as np
class CTCPredict(object):
"""
CTC predict
Args:
params(object): Params from the yaml file and settings from the command line
"""
def __init__(self, params):
super(CTCPredict, self).__init__()
self.char_num = params['char_num']

View File

@ -33,6 +33,7 @@ class AttentionLoss(object):
predict = predicts['predict']
label_out = labels['label_out']
label_out = fluid.layers.cast(x=label_out, dtype='int64')
# calculate attention loss
cost = fluid.layers.cross_entropy(input=predict, label=label_out)
sum_cost = fluid.layers.reduce_sum(cost)
return sum_cost
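As a reference point, a minimal NumPy sketch of the same computation (per-step cross-entropy summed over all steps), with illustrative probabilities and labels; it mirrors the cross_entropy plus reduce_sum pair above but is not Paddle code.

import numpy as np

# predict: [num_steps, num_classes] softmax probabilities; label_out: int64 target per step
predict = np.array([[0.7, 0.2, 0.1],
                    [0.1, 0.8, 0.1]])
label_out = np.array([0, 1])

# cross_entropy per step is -log p(correct class); reduce_sum adds them up
per_step = -np.log(predict[np.arange(len(label_out)), label_out])
sum_cost = per_step.sum()
print(sum_cost)   # ~0.58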

View File

@ -30,6 +30,7 @@ class CTCLoss(object):
def __call__(self, predicts, labels):
predict = predicts['predict']
label = labels['label']
# calculate ctc loss
cost = fluid.layers.warpctc(
input=predict, label=label, blank=self.char_num, norm_by_times=True)
sum_cost = fluid.layers.reduce_sum(cost)
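One detail implied by blank=self.char_num above: the CTC blank label is placed after all real characters. A minimal sketch of the resulting index layout, using the default English dictionary size as an example:

# real characters occupy indices [0, char_num); the CTC blank sits at index char_num,
# so the CTC head must produce char_num + 1 classes
char_num = len("0123456789abcdefghijklmnopqrstuvwxyz")   # 36 for the default en dict
blank_index = char_num                                    # 36, passed as blank= to warpctc
num_classes = char_num + 1                                # 37 softmax outputs
print(blank_index, num_classes)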

View File

@ -20,15 +20,21 @@ import sys
class CharacterOps(object):
""" Convert between text-label and text-index """
"""
Convert between text-label and text-index
Args:
config: config from yaml file
"""
def __init__(self, config):
self.character_type = config['character_type']
self.loss_type = config['loss_type']
self.max_text_len = config['max_text_length']
# use the default dictionary (36 characters)
if self.character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
# use the custom dictionary
elif self.character_type in [
"ch", 'japan', 'korean', 'french', 'german'
]:
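For the custom-dictionary branch, a minimal sketch of reading one character per line from a dictionary file; the helper and the commented path are illustrative, while the repo takes the path from its config.

def load_char_dict(dict_path):
    # one character per line, utf-8 encoded; strip only the trailing newline
    dict_character = []
    with open(dict_path, "rb") as f:
        for line in f:
            dict_character.append(line.decode("utf-8").rstrip("\r\n"))
    return dict_character

# dict_character = load_char_dict("path/to/char_dict.txt")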
@ -55,25 +61,27 @@ class CharacterOps(object):
"Nonsupport type of the character: {}".format(self.character_str)
self.beg_str = "sos"
self.end_str = "eos"
# add start and end str for attention
if self.loss_type == "attention":
dict_character = [self.beg_str, self.end_str] + dict_character
elif self.loss_type == "srn":
dict_character = dict_character + [self.beg_str, self.end_str]
# create char dict
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character
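A minimal sketch of the two commented steps above, with a toy character set: attention prepends the begin/end tokens, then every symbol is assigned an integer index.

dict_character = list("abc")
beg_str, end_str = "sos", "eos"

loss_type = "attention"
if loss_type == "attention":
    # begin/end tokens go in front of the real characters for attention decoding
    dict_character = [beg_str, end_str] + dict_character

# create char dict: map each symbol to its index
char_dict = {char: i for i, char in enumerate(dict_character)}
print(char_dict)   # {'sos': 0, 'eos': 1, 'a': 2, 'b': 3, 'c': 4}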
def encode(self, text):
"""convert text-label into text-index.
input:
"""
convert text-label into text-index.
Args:
text: text labels of each image. [batch_size]
output:
Return:
text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
# ignore case: English labels are lowercased
if self.character_type == "en":
text = text.lower()
@ -86,7 +94,15 @@ class CharacterOps(object):
return text
def decode(self, text_index, is_remove_duplicate=False):
""" convert text-index into text-label. """
"""
convert text-index into text-label.
Args:
text_index: text index for each image
is_remove_duplicate: whether to remove duplicate characters;
the default is False
Return:
text: text label
"""
char_list = []
char_num = self.get_char_num()
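Taken together, encode and decode are inverses up to the CTC collapse rule. A minimal self-contained sketch with the default English dictionary (duplicate removal shown for is_remove_duplicate=True):

characters = list("0123456789abcdefghijklmnopqrstuvwxyz")
char_dict = {c: i for i, c in enumerate(characters)}

def encode(text):
    text = text.lower()                            # English labels are lowercased, as above
    return [char_dict[c] for c in text if c in char_dict]

def decode(text_index, is_remove_duplicate=False):
    chars = []
    for i, idx in enumerate(text_index):
        if is_remove_duplicate and i > 0 and text_index[i - 1] == idx:
            continue                               # drop consecutive duplicate indices
        chars.append(characters[idx])
    return "".join(chars)

ids = encode("Hello")
print(ids)                                         # [17, 14, 21, 21, 24]
print(decode(ids))                                 # hello
print(decode(ids, is_remove_duplicate=True))       # helo (the repeated 'l' collapses)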
@ -108,6 +124,9 @@ class CharacterOps(object):
return text
def get_char_num(self):
"""
Get character num
"""
return len(self.character)
def get_beg_end_flag_idx(self, beg_or_end):
@ -132,6 +151,21 @@ def cal_predicts_accuracy(char_ops,
labels,
labels_lod,
is_remove_duplicate=False):
"""
Calculate prediction accuracy
Args:
char_ops: CharacterOps
preds: prediction results, text index
preds_lod: lod information of preds
labels: labels of the input images, text index
labels_lod: lod information of labels
is_remove_duplicate: whether to remove duplicate characters;
the default is False
Return:
acc: accuracy on the test set
acc_num: number of correctly predicted samples
img_num: total number of samples in the test set
"""
acc_num = 0
img_num = 0
for ino in range(len(labels_lod) - 1):
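A minimal sketch of the loop the docstring describes, with toy data: each lod list holds the start/end offsets of one image's text indices inside the flat preds/labels arrays, and a sample counts as correct only if the whole sequence matches.

preds      = [1, 2, 3,   4, 5]        # two images' predicted text indices, flattened
preds_lod  = [0, 3, 5]
labels     = [1, 2, 3,   4, 6]        # ground-truth text indices, flattened
labels_lod = [0, 3, 5]

acc_num, img_num = 0, 0
for ino in range(len(labels_lod) - 1):
    pred  = preds[preds_lod[ino]:preds_lod[ino + 1]]
    label = labels[labels_lod[ino]:labels_lod[ino + 1]]
    img_num += 1
    if pred == label:
        acc_num += 1

acc = acc_num * 1.0 / img_num
print(acc, acc_num, img_num)          # 0.5 1 2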
@ -189,6 +223,14 @@ def cal_predicts_accuracy_srn(char_ops,
def convert_rec_attention_infer_res(preds):
"""
Convert recognition attention prediction results with lod information
Args:
preds: the output of the model
Return:
convert_ids: a 1-D tensor representing all the predicted results
target_lod: lod information of the predicted results
"""
img_num = preds.shape[0]
target_lod = [0]
convert_ids = []
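A minimal sketch of what this conversion produces with toy data: each image's predicted ids are cut at the first end token, concatenated into one flat list, and the running lengths recorded as lod offsets (the end-token index here is illustrative).

import numpy as np

end_idx = 1                                # illustrative eos index
preds = np.array([[5, 7, 1, 0],            # image 0: stops at the eos (1)
                  [9, 1, 0, 0]])           # image 1

target_lod = [0]
convert_ids = []
for row in preds:
    ids = []
    for idx in row:
        if idx == end_idx:
            break                          # everything after the end token is dropped
        ids.append(int(idx))
    convert_ids.extend(ids)
    target_lod.append(target_lod[-1] + len(ids))

print(convert_ids)                         # [5, 7, 9]
print(target_lod)                          # [0, 2, 3]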

View File

@ -122,7 +122,9 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
def test_rec_benchmark(exe, config, eval_info_dict):
" Evaluate lmdb dataset "
"""
Evaluate the recognition model on the lmdb benchmark datasets
"""
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', \
'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
eval_data_dir = config['TestReader']['lmdb_sets_dir']
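The function walks eval_data_list and reports per-set results; a minimal sketch of the bookkeeping needed to turn per-dataset accuracies into an image-weighted average. The numbers and the eval_one_lmdb_set helper are illustrative placeholders, not this repo's API.

eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860']

def eval_one_lmdb_set(name):
    # placeholder: would run the recognition eval on one lmdb set and return (acc, img_num)
    return {'IIIT5k_3000': (0.90, 3000), 'SVT': (0.85, 647), 'IC03_860': (0.92, 860)}[name]

total_correct, total_img = 0.0, 0
for name in eval_data_list:
    acc, img_num = eval_one_lmdb_set(name)
    total_correct += acc * img_num
    total_img += img_num
    print("%s: acc=%.4f (%d images)" % (name, acc, img_num))

print("average acc over all sets: %.4f" % (total_correct / total_img))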

View File

@ -150,19 +150,20 @@ def check_gpu(use_gpu):
def build(config, main_prog, startup_prog, mode):
"""
Build a program using a model and an optimizer
1. create feeds
2. create a dataloader
3. create a model
4. create fetchs
5. create an optimizer
1. create a dataloader
2. create a model
3. create fetches
4. create an optimizer
Args:
config(dict): config
main_prog(): main program
startup_prog(): startup program
is_train(bool): train or valid
mode(str): train or valid
Returns:
dataloader(): a bridge between the model and the data
fetchs(dict): dict of model outputs(included loss and measures)
fetch_name_list(dict): dict of model outputs (including loss and measures)
fetch_varname_list(list): list of output variable names
opt_loss_name(str): name of the loss
"""
with fluid.program_guard(main_prog, startup_prog):
with fluid.unique_name.guard():
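A minimal sketch of the four steps the docstring lists, assuming a model object whose __call__ returns (dataloader, outputs) as RecModel does; this shows the general pattern, not the repo's exact build().

import paddle.fluid as fluid

def build_sketch(model, optimizer, main_prog, startup_prog, mode):
    with fluid.program_guard(main_prog, startup_prog):
        with fluid.unique_name.guard():
            dataloader, outputs = model(mode=mode)              # 1. dataloader, 2. model
            fetch_name_list = list(outputs.keys())              # 3. fetches to read back
            fetch_varname_list = [outputs[k].name for k in fetch_name_list]
            opt_loss_name = None
            if mode == "train":                                 # 4. optimizer, train only
                opt_loss = outputs['total_loss']
                optimizer.minimize(opt_loss)
                opt_loss_name = opt_loss.name
    return dataloader, fetch_name_list, fetch_varname_list, opt_loss_name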
@ -207,8 +208,8 @@ def build_export(config, main_prog, startup_prog):
Build input and output for exporting a checkpoints model to an inference model
Args:
config(dict): config
main_prog(): main program
startup_prog(): startup program
main_prog: main program
startup_prog: startup program
Returns:
feeded_var_names(list[str]): var names of input for exported inference model
target_vars(list[Variable]): output vars for exported inference model
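For the export path, a minimal sketch of saving an inference model from those feed/fetch variables with fluid's standard API; the directory and example variable names are illustrative.

import paddle.fluid as fluid

def export_sketch(exe, main_prog, feeded_var_names, target_vars, save_dir="./inference"):
    # writes the pruned inference program plus parameters under save_dir
    fluid.io.save_inference_model(
        dirname=save_dir,
        feeded_var_names=feeded_var_names,   # e.g. ['image']
        target_vars=target_vars,             # e.g. [decoded_out, predicts]
        executor=exe,
        main_program=main_prog)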
@ -257,9 +258,14 @@ def train_eval_det_run(config,
train_info_dict,
eval_info_dict,
is_slim=None):
'''
main program of evaluation for detection
'''
"""
Feed data to the model and fetch the measures and loss for detection
Args:
config: config
exe: the executor that runs the program
train_info_dict: information dict for training
eval_info_dict: information dict for evaluation
"""
train_batch_id = 0
log_smooth_window = config['Global']['log_smooth_window']
epoch_num = config['Global']['epoch_num']
@ -376,9 +382,14 @@ def train_eval_rec_run(config,
train_info_dict,
eval_info_dict,
is_slim=None):
'''
main program of evaluation for recognition
'''
"""
Feed data to the model and fetch the measures and loss for recognition
Args:
config: config
exe: the executor that runs the program
train_info_dict: information dict for training
eval_info_dict: information dict for evaluation
"""
train_batch_id = 0
log_smooth_window = config['Global']['log_smooth_window']
epoch_num = config['Global']['epoch_num']
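A minimal sketch of the feed-and-fetch loop both docstrings describe: iterate the dataloader, run the program through the executor, and read the fetched loss and measures back by name. The train_info_dict key names are illustrative, not necessarily the repo's exact ones.

def train_one_epoch_sketch(exe, train_info_dict):
    program = train_info_dict['program']
    loader = train_info_dict['train_loader']
    fetch_name_list = train_info_dict['fetch_name_list']
    fetch_varname_list = train_info_dict['fetch_varname_list']
    for data in loader():
        # run one step: feed a batch, fetch the variables named in fetch_varname_list
        outs = exe.run(program, feed=data, fetch_list=fetch_varname_list)
        metrics = dict(zip(fetch_name_list, outs))
        print(metrics)   # the repo smooths and logs these every log_smooth_window steps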