add anno for rec
This commit is contained in:
parent
8b64f4c2ea
commit
fc512a84e3
|
@ -25,6 +25,12 @@ from copy import deepcopy
|
|||
|
||||
|
||||
class RecModel(object):
|
||||
"""
|
||||
Rec model architecture
|
||||
Args:
|
||||
params(object): Params from yaml file and settings from command line
|
||||
"""
|
||||
|
||||
def __init__(self, params):
|
||||
super(RecModel, self).__init__()
|
||||
global_params = params['Global']
|
||||
|
@ -64,6 +70,12 @@ class RecModel(object):
|
|||
self.num_heads = None
|
||||
|
||||
def create_feed(self, mode):
|
||||
"""
|
||||
Create feed dict and DataLoader object
|
||||
Args:
|
||||
mode(str): runtime mode, can be "train", "eval" or "test"
|
||||
Return: image, labels, loader
|
||||
"""
|
||||
image_shape = deepcopy(self.image_shape)
|
||||
image_shape.insert(0, -1)
|
||||
if mode == "train":
|
||||
|
@ -189,9 +201,12 @@ class RecModel(object):
|
|||
inputs = image
|
||||
else:
|
||||
inputs = self.tps(image)
|
||||
# backbone
|
||||
conv_feas = self.backbone(inputs)
|
||||
# predict
|
||||
predicts = self.head(conv_feas, labels, mode)
|
||||
decoded_out = predicts['decoded_out']
|
||||
# loss
|
||||
if mode == "train":
|
||||
loss = self.loss(predicts, labels)
|
||||
if self.loss_type == "attention":
|
||||
|
@ -211,7 +226,7 @@ class RecModel(object):
|
|||
outputs = {'total_loss':loss, 'decoded_out':\
|
||||
decoded_out, 'label':label}
|
||||
return loader, outputs
|
||||
|
||||
# export_model
|
||||
elif mode == "export":
|
||||
predict = predicts['predict']
|
||||
if self.loss_type == "ctc":
|
||||
|
@ -225,6 +240,7 @@ class RecModel(object):
|
|||
]
|
||||
|
||||
return [image, {'decoded_out': decoded_out, 'predicts': predict}]
|
||||
# eval or test
|
||||
else:
|
||||
predict = predicts['predict']
|
||||
if self.loss_type == "ctc":
|
||||
|
|
|
@ -27,6 +27,12 @@ import numpy as np
|
|||
|
||||
|
||||
class CTCPredict(object):
|
||||
"""
|
||||
CTC predict
|
||||
Args:
|
||||
params(object): Params from yaml file and settings from command line
|
||||
"""
|
||||
|
||||
def __init__(self, params):
|
||||
super(CTCPredict, self).__init__()
|
||||
self.char_num = params['char_num']
|
||||
|
|
|
@ -33,6 +33,7 @@ class AttentionLoss(object):
|
|||
predict = predicts['predict']
|
||||
label_out = labels['label_out']
|
||||
label_out = fluid.layers.cast(x=label_out, dtype='int64')
|
||||
# calculate attention loss
|
||||
cost = fluid.layers.cross_entropy(input=predict, label=label_out)
|
||||
sum_cost = fluid.layers.reduce_sum(cost)
|
||||
return sum_cost
|
||||
|
|
|
@ -30,6 +30,7 @@ class CTCLoss(object):
|
|||
def __call__(self, predicts, labels):
|
||||
predict = predicts['predict']
|
||||
label = labels['label']
|
||||
# calculate ctc loss
|
||||
cost = fluid.layers.warpctc(
|
||||
input=predict, label=label, blank=self.char_num, norm_by_times=True)
|
||||
sum_cost = fluid.layers.reduce_sum(cost)
|
||||
|
|
|
@ -20,15 +20,21 @@ import sys
|
|||
|
||||
|
||||
class CharacterOps(object):
|
||||
""" Convert between text-label and text-index """
|
||||
"""
|
||||
Convert between text-label and text-index
|
||||
Args:
|
||||
config: config from yaml file
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
self.character_type = config['character_type']
|
||||
self.loss_type = config['loss_type']
|
||||
self.max_text_len = config['max_text_length']
|
||||
# use the default dictionary(36 char)
|
||||
if self.character_type == "en":
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
# use the custom dictionary
|
||||
elif self.character_type in [
|
||||
"ch", 'japan', 'korean', 'french', 'german'
|
||||
]:
|
||||
|
@ -55,25 +61,27 @@ class CharacterOps(object):
|
|||
"Nonsupport type of the character: {}".format(self.character_str)
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
# add start and end str for attention
|
||||
if self.loss_type == "attention":
|
||||
dict_character = [self.beg_str, self.end_str] + dict_character
|
||||
elif self.loss_type == "srn":
|
||||
dict_character = dict_character + [self.beg_str, self.end_str]
|
||||
# create char dict
|
||||
self.dict = {}
|
||||
for i, char in enumerate(dict_character):
|
||||
self.dict[char] = i
|
||||
self.character = dict_character
|
||||
|
||||
def encode(self, text):
|
||||
"""convert text-label into text-index.
|
||||
input:
|
||||
"""
|
||||
convert text-label into text-index.
|
||||
Args:
|
||||
text: text labels of each image. [batch_size]
|
||||
|
||||
output:
|
||||
Return:
|
||||
text: concatenated text index for CTCLoss.
|
||||
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
|
||||
length: length of each text. [batch_size]
|
||||
"""
|
||||
# Ignore capital
|
||||
if self.character_type == "en":
|
||||
text = text.lower()
|
||||
|
||||
|
@ -86,7 +94,15 @@ class CharacterOps(object):
|
|||
return text
|
||||
|
||||
def decode(self, text_index, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
"""
|
||||
convert text-index into text-label.
|
||||
Args:
|
||||
text_index: text index for each image
|
||||
is_remove_duplicate: Whether to remove duplicate characters,
|
||||
The default is False
|
||||
Return:
|
||||
text: text label
|
||||
"""
|
||||
char_list = []
|
||||
char_num = self.get_char_num()
|
||||
|
||||
|
@ -108,6 +124,9 @@ class CharacterOps(object):
|
|||
return text
|
||||
|
||||
def get_char_num(self):
|
||||
"""
|
||||
Get the number of characters in the dictionary
|
||||
"""
|
||||
return len(self.character)
|
||||
|
||||
def get_beg_end_flag_idx(self, beg_or_end):
|
||||
|
@ -132,6 +151,21 @@ def cal_predicts_accuracy(char_ops,
|
|||
labels,
|
||||
labels_lod,
|
||||
is_remove_duplicate=False):
|
||||
"""
|
||||
Calculate prediction accuracy
|
||||
Args:
|
||||
char_ops: CharacterOps
|
||||
preds: preds result,text index
|
||||
preds_lod: lod tensor of preds
|
||||
labels: label of input image, text index
|
||||
labels_lod: lod tensor of label
|
||||
is_remove_duplicate: Whether to remove duplicate characters,
|
||||
The default is False
|
||||
Return:
|
||||
acc: The accuracy of test set
|
||||
acc_num: The correct number of samples predicted
|
||||
img_num: The total sample number of the test set
|
||||
"""
|
||||
acc_num = 0
|
||||
img_num = 0
|
||||
for ino in range(len(labels_lod) - 1):
|
||||
|
@ -189,6 +223,14 @@ def cal_predicts_accuracy_srn(char_ops,
|
|||
|
||||
|
||||
def convert_rec_attention_infer_res(preds):
|
||||
"""
|
||||
Convert recognition attention predict result with lod information
|
||||
Args:
|
||||
preds: the output of the model
|
||||
Return:
|
||||
convert_ids: A 1-D Tensor represents all the predicted results.
|
||||
target_lod: The lod information of the predicted results
|
||||
"""
|
||||
img_num = preds.shape[0]
|
||||
target_lod = [0]
|
||||
convert_ids = []
|
||||
|
|
|
@ -122,7 +122,9 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
|
|||
|
||||
|
||||
def test_rec_benchmark(exe, config, eval_info_dict):
|
||||
" Evaluate lmdb dataset "
|
||||
"""
|
||||
Evaluate the recognition model on the benchmark lmdb datasets
|
||||
"""
|
||||
eval_data_list = ['IIIT5k_3000', 'SVT', 'IC03_860', 'IC03_867', \
|
||||
'IC13_857', 'IC13_1015', 'IC15_1811', 'IC15_2077', 'SVTP', 'CUTE80']
|
||||
eval_data_dir = config['TestReader']['lmdb_sets_dir']
|
||||
|
|
|
@ -150,19 +150,20 @@ def check_gpu(use_gpu):
|
|||
def build(config, main_prog, startup_prog, mode):
|
||||
"""
|
||||
Build a program using a model and an optimizer
|
||||
1. create feeds
|
||||
2. create a dataloader
|
||||
3. create a model
|
||||
4. create fetchs
|
||||
5. create an optimizer
|
||||
1. create a dataloader
|
||||
2. create a model
|
||||
3. create fetches
|
||||
4. create an optimizer
|
||||
Args:
|
||||
config(dict): config
|
||||
main_prog(): main program
|
||||
startup_prog(): startup program
|
||||
is_train(bool): train or valid
|
||||
mode(str): train or valid
|
||||
Returns:
|
||||
dataloader(): a bridge between the model and the data
|
||||
fetchs(dict): dict of model outputs(included loss and measures)
|
||||
fetch_name_list(dict): dict of model outputs(included loss and measures)
|
||||
fetch_varname_list(list): list of outputs' varname
|
||||
opt_loss_name(str): name of loss
|
||||
"""
|
||||
with fluid.program_guard(main_prog, startup_prog):
|
||||
with fluid.unique_name.guard():
|
||||
|
@ -257,9 +258,14 @@ def train_eval_det_run(config,
|
|||
train_info_dict,
|
||||
eval_info_dict,
|
||||
is_slim=None):
|
||||
'''
|
||||
main program of evaluation for detection
|
||||
'''
|
||||
"""
|
||||
Feed data to the model and fetch the measures and loss for detection
|
||||
Args:
|
||||
config: config
|
||||
exe:
|
||||
train_info_dict: information dict for training
|
||||
eval_info_dict: information dict for evaluation
|
||||
"""
|
||||
train_batch_id = 0
|
||||
log_smooth_window = config['Global']['log_smooth_window']
|
||||
epoch_num = config['Global']['epoch_num']
|
||||
|
@ -376,9 +382,14 @@ def train_eval_rec_run(config,
|
|||
train_info_dict,
|
||||
eval_info_dict,
|
||||
is_slim=None):
|
||||
'''
|
||||
main program of evaluation for recognition
|
||||
'''
|
||||
"""
|
||||
Feed data to the model and fetch the measures and loss for recognition
|
||||
Args:
|
||||
config: config
|
||||
exe:
|
||||
train_info_dict: information dict for training
|
||||
eval_info_dict: information dict for evaluation
|
||||
"""
|
||||
train_batch_id = 0
|
||||
log_smooth_window = config['Global']['log_smooth_window']
|
||||
epoch_num = config['Global']['epoch_num']
|
||||
|
|
Loading…
Reference in New Issue