add train code for table
This commit is contained in:
parent
e93735a2ef
commit
7c8b2c8d19
|
@ -0,0 +1,116 @@
|
||||||
|
Global:
|
||||||
|
use_gpu: true
|
||||||
|
epoch_num: 40
|
||||||
|
log_smooth_window: 20
|
||||||
|
print_batch_step: 5
|
||||||
|
save_model_dir: ./output/table_mv3/
|
||||||
|
save_epoch_step: 3
|
||||||
|
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||||
|
eval_batch_step: [0, 400]
|
||||||
|
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||||
|
cal_metric_during_train: True
|
||||||
|
pretrained_model:
|
||||||
|
checkpoints:
|
||||||
|
save_inference_dir:
|
||||||
|
use_visualdl: False
|
||||||
|
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||||
|
# for data or label process
|
||||||
|
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
|
||||||
|
character_type: en
|
||||||
|
max_text_length: 100
|
||||||
|
max_elem_length: 800
|
||||||
|
max_cell_num: 500
|
||||||
|
infer_mode: False
|
||||||
|
process_total_num: 0
|
||||||
|
process_cut_num: 0
|
||||||
|
|
||||||
|
Optimizer:
|
||||||
|
name: Adam
|
||||||
|
beta1: 0.9
|
||||||
|
beta2: 0.999
|
||||||
|
clip_norm: 5.0
|
||||||
|
lr:
|
||||||
|
learning_rate: 0.0001
|
||||||
|
regularizer:
|
||||||
|
name: 'L2'
|
||||||
|
factor: 0.00000
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
model_type: table
|
||||||
|
algorithm: TableAttn
|
||||||
|
Backbone:
|
||||||
|
name: MobileNetV3
|
||||||
|
scale: 1.0
|
||||||
|
model_name: large
|
||||||
|
Head:
|
||||||
|
name: TableAttentionHead # AttentionHead
|
||||||
|
hidden_size: 256 #
|
||||||
|
l2_decay: 0.00001
|
||||||
|
# loc_type: 1
|
||||||
|
loc_type: 2
|
||||||
|
|
||||||
|
Loss:
|
||||||
|
name: TableAttentionLoss
|
||||||
|
structure_weight: 100.0
|
||||||
|
loc_weight: 10000.0
|
||||||
|
|
||||||
|
PostProcess:
|
||||||
|
name: TableLabelDecode
|
||||||
|
|
||||||
|
Metric:
|
||||||
|
name: TableMetric
|
||||||
|
main_indicator: acc
|
||||||
|
|
||||||
|
Train:
|
||||||
|
dataset:
|
||||||
|
name: PubTabDataSet
|
||||||
|
data_dir: train_data/table/pubtabnet/train/
|
||||||
|
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- ResizeTableImage:
|
||||||
|
max_len: 488
|
||||||
|
- TableLabelEncode:
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: 'hwc'
|
||||||
|
- PaddingTableImage:
|
||||||
|
- ToCHWImage:
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
|
||||||
|
loader:
|
||||||
|
shuffle: True
|
||||||
|
batch_size_per_card: 32
|
||||||
|
drop_last: True
|
||||||
|
num_workers: 4
|
||||||
|
|
||||||
|
Eval:
|
||||||
|
dataset:
|
||||||
|
name: PubTabDataSet
|
||||||
|
data_dir: train_data/table/pubtabnet/val/
|
||||||
|
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl
|
||||||
|
transforms:
|
||||||
|
- DecodeImage: # load image
|
||||||
|
img_mode: BGR
|
||||||
|
channel_first: False
|
||||||
|
- ResizeTableImage:
|
||||||
|
max_len: 488
|
||||||
|
- TableLabelEncode:
|
||||||
|
- NormalizeImage:
|
||||||
|
scale: 1./255.
|
||||||
|
mean: [0.485, 0.456, 0.406]
|
||||||
|
std: [0.229, 0.224, 0.225]
|
||||||
|
order: 'hwc'
|
||||||
|
- PaddingTableImage:
|
||||||
|
- ToCHWImage:
|
||||||
|
- KeepKeys:
|
||||||
|
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
|
||||||
|
loader:
|
||||||
|
shuffle: False
|
||||||
|
drop_last: False
|
||||||
|
batch_size_per_card: 16
|
||||||
|
num_workers: 4
|
|
@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators
|
||||||
from ppocr.data.simple_dataset import SimpleDataSet
|
from ppocr.data.simple_dataset import SimpleDataSet
|
||||||
from ppocr.data.lmdb_dataset import LMDBDataSet
|
from ppocr.data.lmdb_dataset import LMDBDataSet
|
||||||
from ppocr.data.pgnet_dataset import PGDataSet
|
from ppocr.data.pgnet_dataset import PGDataSet
|
||||||
|
from ppocr.data.pubtab_dataset import PubTabDataSet
|
||||||
|
|
||||||
__all__ = ['build_dataloader', 'transform', 'create_operators']
|
__all__ = ['build_dataloader', 'transform', 'create_operators']
|
||||||
|
|
||||||
|
@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp)
|
||||||
def build_dataloader(config, mode, device, logger, seed=None):
|
def build_dataloader(config, mode, device, logger, seed=None):
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
|
|
||||||
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet']
|
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet']
|
||||||
module_name = config[mode]['dataset']['name']
|
module_name = config[mode]['dataset']['name']
|
||||||
assert module_name in support_dict, Exception(
|
assert module_name in support_dict, Exception(
|
||||||
'DataSet only support {}'.format(support_dict))
|
'DataSet only support {}'.format(support_dict))
|
||||||
|
|
|
@ -351,3 +351,182 @@ class SRNLabelEncode(BaseRecLabelEncode):
|
||||||
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
|
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
|
||||||
% beg_or_end
|
% beg_or_end
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
|
class TableLabelEncode(object):
|
||||||
|
""" Convert between text-label and text-index """
|
||||||
|
def __init__(self,
|
||||||
|
max_text_length,
|
||||||
|
max_elem_length,
|
||||||
|
max_cell_num,
|
||||||
|
character_dict_path,
|
||||||
|
span_weight = 1.0,
|
||||||
|
**kwargs):
|
||||||
|
self.max_text_length = max_text_length
|
||||||
|
self.max_elem_length = max_elem_length
|
||||||
|
self.max_cell_num = max_cell_num
|
||||||
|
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
|
||||||
|
list_character = self.add_special_char(list_character)
|
||||||
|
list_elem = self.add_special_char(list_elem)
|
||||||
|
self.dict_character = {}
|
||||||
|
for i, char in enumerate(list_character):
|
||||||
|
self.dict_character[char] = i
|
||||||
|
self.dict_elem = {}
|
||||||
|
for i, elem in enumerate(list_elem):
|
||||||
|
self.dict_elem[elem] = i
|
||||||
|
self.span_weight = span_weight
|
||||||
|
|
||||||
|
def load_char_elem_dict(self, character_dict_path):
|
||||||
|
list_character = []
|
||||||
|
list_elem = []
|
||||||
|
with open(character_dict_path, "rb") as fin:
|
||||||
|
lines = fin.readlines()
|
||||||
|
substr = lines[0].decode('utf-8').strip("\n").split("\t")
|
||||||
|
character_num = int(substr[0])
|
||||||
|
elem_num = int(substr[1])
|
||||||
|
for cno in range(1, 1+character_num):
|
||||||
|
character = lines[cno].decode('utf-8').strip("\n")
|
||||||
|
list_character.append(character)
|
||||||
|
for eno in range(1+character_num, 1+character_num+elem_num):
|
||||||
|
elem = lines[eno].decode('utf-8').strip("\n")
|
||||||
|
list_elem.append(elem)
|
||||||
|
return list_character, list_elem
|
||||||
|
|
||||||
|
def add_special_char(self, list_character):
|
||||||
|
self.beg_str = "sos"
|
||||||
|
self.end_str = "eos"
|
||||||
|
list_character = [self.beg_str] + list_character + [self.end_str]
|
||||||
|
return list_character
|
||||||
|
|
||||||
|
def get_span_idx_list(self):
|
||||||
|
span_idx_list = []
|
||||||
|
for elem in self.dict_elem:
|
||||||
|
if 'span' in elem:
|
||||||
|
span_idx_list.append(self.dict_elem[elem])
|
||||||
|
return span_idx_list
|
||||||
|
|
||||||
|
def __call__(self, data):
|
||||||
|
cells = data['cells']
|
||||||
|
structure = data['structure']['tokens']
|
||||||
|
structure = self.encode(structure, 'elem')
|
||||||
|
if structure is None:
|
||||||
|
return None
|
||||||
|
elem_num = len(structure)
|
||||||
|
structure = [0] + structure + [len(self.dict_elem) - 1]
|
||||||
|
# structure = [0] + structure + [0]
|
||||||
|
structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
|
||||||
|
structure = np.array(structure)
|
||||||
|
data['structure'] = structure
|
||||||
|
elem_char_idx1 = self.dict_elem['<td>']
|
||||||
|
elem_char_idx2 = self.dict_elem['<td']
|
||||||
|
span_idx_list = self.get_span_idx_list()
|
||||||
|
td_idx_list = np.logical_or(structure == elem_char_idx1, structure == elem_char_idx2)
|
||||||
|
td_idx_list = np.where(td_idx_list)[0]
|
||||||
|
|
||||||
|
structure_mask = np.ones((self.max_elem_length + 2, 1), dtype=np.float32)
|
||||||
|
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
|
||||||
|
bbox_list_mask = np.zeros((self.max_elem_length + 2, 1), dtype=np.float32)
|
||||||
|
img_height, img_width, img_ch = data['image'].shape
|
||||||
|
if len(span_idx_list) > 0:
|
||||||
|
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
|
||||||
|
span_weight = min(max(span_weight, 1.0), self.span_weight)
|
||||||
|
for cno in range(len(cells)):
|
||||||
|
if 'bbox' in cells[cno]:
|
||||||
|
bbox = cells[cno]['bbox'].copy()
|
||||||
|
bbox[0] = bbox[0] * 1.0 / img_width
|
||||||
|
bbox[1] = bbox[1] * 1.0 / img_height
|
||||||
|
bbox[2] = bbox[2] * 1.0 / img_width
|
||||||
|
bbox[3] = bbox[3] * 1.0 / img_height
|
||||||
|
td_idx = td_idx_list[cno]
|
||||||
|
bbox_list[td_idx] = bbox
|
||||||
|
bbox_list_mask[td_idx] = 1.0
|
||||||
|
cand_span_idx = td_idx + 1
|
||||||
|
if cand_span_idx < (self.max_elem_length + 2):
|
||||||
|
if structure[cand_span_idx] in span_idx_list:
|
||||||
|
structure_mask[cand_span_idx] = span_weight
|
||||||
|
# structure_mask[td_idx] = self.span_weight
|
||||||
|
# structure_mask[cand_span_idx] = self.span_weight
|
||||||
|
|
||||||
|
data['bbox_list'] = bbox_list
|
||||||
|
data['bbox_list_mask'] = bbox_list_mask
|
||||||
|
data['structure_mask'] = structure_mask
|
||||||
|
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
|
||||||
|
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
|
||||||
|
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
|
||||||
|
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
|
||||||
|
data['sp_tokens'] = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
|
||||||
|
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
|
||||||
|
self.max_elem_length, self.max_cell_num, elem_num])
|
||||||
|
return data
|
||||||
|
|
||||||
|
########
|
||||||
|
# for char decode
|
||||||
|
# cell_list = []
|
||||||
|
# for cell in cells:
|
||||||
|
# char_list = cell['tokens']
|
||||||
|
# cell = self.encode(char_list, 'char')
|
||||||
|
# if cell is None:
|
||||||
|
# return None
|
||||||
|
# cell = [0] + cell + [len(self.dict_character) - 1]
|
||||||
|
# cell = cell + [0] * (self.max_text_length + 2 - len(cell))
|
||||||
|
# cell_list.append(cell)
|
||||||
|
# cell_list_padding = np.zeros((self.max_cell_num, self.max_text_length + 2))
|
||||||
|
# cell_list = np.array(cell_list)
|
||||||
|
# cell_list_padding[0:cell_list.shape[0]] = cell_list
|
||||||
|
# data['cells'] = cell_list_padding
|
||||||
|
# return data
|
||||||
|
|
||||||
|
def encode(self, text, char_or_elem):
|
||||||
|
"""convert text-label into text-index.
|
||||||
|
"""
|
||||||
|
if char_or_elem == "char":
|
||||||
|
max_len = self.max_text_length
|
||||||
|
current_dict = self.dict_character
|
||||||
|
else:
|
||||||
|
max_len = self.max_elem_length
|
||||||
|
current_dict = self.dict_elem
|
||||||
|
if len(text) > max_len:
|
||||||
|
return None
|
||||||
|
if len(text) == 0:
|
||||||
|
if char_or_elem == "char":
|
||||||
|
return [self.dict_character['space']]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
text_list = []
|
||||||
|
for char in text:
|
||||||
|
if char not in current_dict:
|
||||||
|
return None
|
||||||
|
text_list.append(current_dict[char])
|
||||||
|
if len(text_list) == 0:
|
||||||
|
if char_or_elem == "char":
|
||||||
|
return [self.dict_character['space']]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
return text_list
|
||||||
|
|
||||||
|
def get_ignored_tokens(self, char_or_elem):
|
||||||
|
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
|
||||||
|
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
|
||||||
|
return [beg_idx, end_idx]
|
||||||
|
|
||||||
|
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
|
||||||
|
if char_or_elem == "char":
|
||||||
|
if beg_or_end == "beg":
|
||||||
|
idx = np.array(self.dict_character[self.beg_str])
|
||||||
|
elif beg_or_end == "end":
|
||||||
|
idx = np.array(self.dict_character[self.end_str])
|
||||||
|
else:
|
||||||
|
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
|
||||||
|
% beg_or_end
|
||||||
|
elif char_or_elem == "elem":
|
||||||
|
if beg_or_end == "beg":
|
||||||
|
idx = np.array(self.dict_elem[self.beg_str])
|
||||||
|
elif beg_or_end == "end":
|
||||||
|
idx = np.array(self.dict_elem[self.end_str])
|
||||||
|
else:
|
||||||
|
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
|
||||||
|
% beg_or_end
|
||||||
|
else:
|
||||||
|
assert False, "Unsupport type %s in char_or_elem" \
|
||||||
|
% char_or_elem
|
||||||
|
return idx
|
||||||
|
|
|
@ -0,0 +1,125 @@
|
||||||
|
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from paddle.io import Dataset
|
||||||
|
import json
|
||||||
|
|
||||||
|
from .imaug import transform, create_operators
|
||||||
|
|
||||||
|
class PubTabDataSet(Dataset):
|
||||||
|
def __init__(self, config, mode, logger, seed=None):
|
||||||
|
super(PubTabDataSet, self).__init__()
|
||||||
|
self.logger = logger
|
||||||
|
|
||||||
|
global_config = config['Global']
|
||||||
|
dataset_config = config[mode]['dataset']
|
||||||
|
loader_config = config[mode]['loader']
|
||||||
|
|
||||||
|
label_file_path = dataset_config.pop('label_file_path')
|
||||||
|
|
||||||
|
self.data_dir = dataset_config['data_dir']
|
||||||
|
self.do_shuffle = loader_config['shuffle']
|
||||||
|
self.do_hard_select = False
|
||||||
|
if 'hard_select' in loader_config:
|
||||||
|
self.do_hard_select = loader_config['hard_select']
|
||||||
|
self.hard_prob = loader_config['hard_prob']
|
||||||
|
if self.do_hard_select:
|
||||||
|
self.img_select_prob = self.load_hard_select_prob()
|
||||||
|
self.table_select_type = None
|
||||||
|
if 'table_select_type' in loader_config:
|
||||||
|
self.table_select_type = loader_config['table_select_type']
|
||||||
|
self.table_select_prob = loader_config['table_select_prob']
|
||||||
|
|
||||||
|
self.seed = seed
|
||||||
|
logger.info("Initialize indexs of datasets:%s" % label_file_path)
|
||||||
|
with open(label_file_path, "rb") as f:
|
||||||
|
self.data_lines = f.readlines()
|
||||||
|
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||||
|
if mode.lower() == "train":
|
||||||
|
self.shuffle_data_random()
|
||||||
|
self.ops = create_operators(dataset_config['transforms'], global_config)
|
||||||
|
|
||||||
|
def shuffle_data_random(self):
|
||||||
|
if self.do_shuffle:
|
||||||
|
random.seed(self.seed)
|
||||||
|
random.shuffle(self.data_lines)
|
||||||
|
return
|
||||||
|
|
||||||
|
def load_hard_select_prob(self):
|
||||||
|
label_path = "./pretrained_model/teds_score_exp5_st2_train.txt"
|
||||||
|
img_select_prob = {}
|
||||||
|
with open(label_path, "rb") as fin:
|
||||||
|
lines = fin.readlines()
|
||||||
|
for lno in range(len(lines)):
|
||||||
|
substr = lines[lno].decode('utf-8').strip("\n").split(" ")
|
||||||
|
img_name = substr[0].strip(":")
|
||||||
|
score = float(substr[1])
|
||||||
|
if score <= 0.8:
|
||||||
|
img_select_prob[img_name] = self.hard_prob[0]
|
||||||
|
elif score <= 0.98:
|
||||||
|
img_select_prob[img_name] = self.hard_prob[1]
|
||||||
|
else:
|
||||||
|
img_select_prob[img_name] = self.hard_prob[2]
|
||||||
|
return img_select_prob
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
try:
|
||||||
|
data_line = self.data_lines[idx]
|
||||||
|
data_line = data_line.decode('utf-8').strip("\n")
|
||||||
|
info = json.loads(data_line)
|
||||||
|
file_name = info['filename']
|
||||||
|
select_flag = True
|
||||||
|
if self.do_hard_select:
|
||||||
|
prob = self.img_select_prob[file_name]
|
||||||
|
if prob < random.uniform(0, 1):
|
||||||
|
select_flag = False
|
||||||
|
|
||||||
|
if self.table_select_type:
|
||||||
|
structure = info['html']['structure']['tokens'].copy()
|
||||||
|
structure_str = ''.join(structure)
|
||||||
|
table_type = "simple"
|
||||||
|
if 'colspan' in structure_str or 'rowspan' in structure_str:
|
||||||
|
table_type = "complex"
|
||||||
|
# if self.table_select_type != table_type:
|
||||||
|
# select_flag = False
|
||||||
|
if table_type == "complex":
|
||||||
|
if self.table_select_prob < random.uniform(0, 1):
|
||||||
|
select_flag = False
|
||||||
|
|
||||||
|
if select_flag:
|
||||||
|
cells = info['html']['cells'].copy()
|
||||||
|
structure = info['html']['structure'].copy()
|
||||||
|
img_path = os.path.join(self.data_dir, file_name)
|
||||||
|
data = {'img_path': img_path, 'cells': cells, 'structure':structure}
|
||||||
|
if not os.path.exists(img_path):
|
||||||
|
raise Exception("{} does not exist!".format(img_path))
|
||||||
|
with open(data['img_path'], 'rb') as f:
|
||||||
|
img = f.read()
|
||||||
|
data['image'] = img
|
||||||
|
outs = transform(data, self.ops)
|
||||||
|
else:
|
||||||
|
outs = None
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(
|
||||||
|
"When parsing line {}, error happened with msg: {}".format(
|
||||||
|
data_line, e))
|
||||||
|
outs = None
|
||||||
|
if outs is None:
|
||||||
|
return self.__getitem__(np.random.randint(self.__len__()))
|
||||||
|
return outs
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data_idx_order_list)
|
|
@ -38,11 +38,13 @@ from .basic_loss import DistanceLoss
|
||||||
# combined loss function
|
# combined loss function
|
||||||
from .combined_loss import CombinedLoss
|
from .combined_loss import CombinedLoss
|
||||||
|
|
||||||
|
# table loss
|
||||||
|
from .table_att_loss import TableAttentionLoss
|
||||||
|
|
||||||
def build_loss(config):
|
def build_loss(config):
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||||
'SRNLoss', 'PGLoss', 'CombinedLoss'
|
'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
|
||||||
]
|
]
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
|
|
|
@ -0,0 +1,109 @@
|
||||||
|
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
from paddle.nn import functional as F
|
||||||
|
from paddle import fluid
|
||||||
|
|
||||||
|
class TableAttentionLoss(nn.Layer):
|
||||||
|
def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
|
||||||
|
super(TableAttentionLoss, self).__init__()
|
||||||
|
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
|
||||||
|
self.structure_weight = structure_weight
|
||||||
|
self.loc_weight = loc_weight
|
||||||
|
self.use_giou = use_giou
|
||||||
|
self.giou_weight = giou_weight
|
||||||
|
|
||||||
|
def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
|
||||||
|
'''
|
||||||
|
:param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
|
||||||
|
:param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
|
||||||
|
:return: loss
|
||||||
|
'''
|
||||||
|
ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0])
|
||||||
|
iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1])
|
||||||
|
ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2])
|
||||||
|
iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3])
|
||||||
|
|
||||||
|
iw = fluid.layers.clip(ix2 - ix1 + 1e-3, 0., 1e10)
|
||||||
|
ih = fluid.layers.clip(iy2 - iy1 + 1e-3, 0., 1e10)
|
||||||
|
|
||||||
|
# overlap
|
||||||
|
inters = iw * ih
|
||||||
|
|
||||||
|
# union
|
||||||
|
uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3
|
||||||
|
) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * (
|
||||||
|
bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps
|
||||||
|
|
||||||
|
# ious
|
||||||
|
ious = inters / uni
|
||||||
|
|
||||||
|
ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0])
|
||||||
|
ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1])
|
||||||
|
ex2 = fluid.layers.elementwise_max(preds[:, 2], bbox[:, 2])
|
||||||
|
ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3])
|
||||||
|
ew = fluid.layers.clip(ex2 - ex1 + 1e-3, 0., 1e10)
|
||||||
|
eh = fluid.layers.clip(ey2 - ey1 + 1e-3, 0., 1e10)
|
||||||
|
|
||||||
|
# enclose erea
|
||||||
|
enclose = ew * eh + eps
|
||||||
|
giou = ious - (enclose - uni) / enclose
|
||||||
|
|
||||||
|
loss = 1 - giou
|
||||||
|
|
||||||
|
if reduction == 'mean':
|
||||||
|
loss = paddle.mean(loss)
|
||||||
|
elif reduction == 'sum':
|
||||||
|
loss = paddle.sum(loss)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def forward(self, predicts, batch):
|
||||||
|
structure_probs = predicts['structure_probs']
|
||||||
|
structure_targets = batch[1].astype("int64")
|
||||||
|
structure_targets = structure_targets[:, 1:]
|
||||||
|
if len(batch) == 6:
|
||||||
|
structure_mask = batch[5].astype("int64")
|
||||||
|
structure_mask = structure_mask[:, 1:]
|
||||||
|
structure_mask = paddle.reshape(structure_mask, [-1])
|
||||||
|
structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
|
||||||
|
structure_targets = paddle.reshape(structure_targets, [-1])
|
||||||
|
structure_loss = self.loss_func(structure_probs, structure_targets)
|
||||||
|
|
||||||
|
if len(batch) == 6:
|
||||||
|
structure_loss = structure_loss * structure_mask
|
||||||
|
|
||||||
|
# structure_loss = paddle.sum(structure_loss) * self.structure_weight
|
||||||
|
structure_loss = paddle.mean(structure_loss) * self.structure_weight
|
||||||
|
|
||||||
|
loc_preds = predicts['loc_preds']
|
||||||
|
loc_targets = batch[2].astype("float32")
|
||||||
|
loc_targets_mask = batch[4].astype("float32")
|
||||||
|
loc_targets = loc_targets[:, 1:, :]
|
||||||
|
loc_targets_mask = loc_targets_mask[:, 1:, :]
|
||||||
|
loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight
|
||||||
|
if self.use_giou:
|
||||||
|
loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight
|
||||||
|
total_loss = structure_loss + loc_loss + loc_loss_giou
|
||||||
|
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou}
|
||||||
|
else:
|
||||||
|
total_loss = structure_loss + loc_loss
|
||||||
|
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss}
|
|
@ -26,11 +26,11 @@ from .rec_metric import RecMetric
|
||||||
from .cls_metric import ClsMetric
|
from .cls_metric import ClsMetric
|
||||||
from .e2e_metric import E2EMetric
|
from .e2e_metric import E2EMetric
|
||||||
from .distillation_metric import DistillationMetric
|
from .distillation_metric import DistillationMetric
|
||||||
|
from .table_metric import TableMetric
|
||||||
|
|
||||||
def build_metric(config):
|
def build_metric(config):
|
||||||
support_dict = [
|
support_dict = [
|
||||||
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric"
|
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric"
|
||||||
]
|
]
|
||||||
|
|
||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import numpy as np
|
||||||
|
class TableMetric(object):
|
||||||
|
def __init__(self, main_indicator='acc', **kwargs):
|
||||||
|
self.main_indicator = main_indicator
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def __call__(self, pred, batch, *args, **kwargs):
|
||||||
|
structure_probs = pred['structure_probs'].numpy()
|
||||||
|
structure_labels = batch[1]
|
||||||
|
correct_num = 0
|
||||||
|
all_num = 0
|
||||||
|
structure_probs = np.argmax(structure_probs, axis=2)
|
||||||
|
structure_labels = structure_labels[:, 1:]
|
||||||
|
batch_size = structure_probs.shape[0]
|
||||||
|
for bno in range(batch_size):
|
||||||
|
all_num += 1
|
||||||
|
if (structure_probs[bno] == structure_labels[bno]).all():
|
||||||
|
correct_num += 1
|
||||||
|
self.correct_num += correct_num
|
||||||
|
self.all_num += all_num
|
||||||
|
return {
|
||||||
|
'acc': correct_num * 1.0 / all_num,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_metric(self):
|
||||||
|
"""
|
||||||
|
return metrics {
|
||||||
|
'acc': 0,
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
acc = 1.0 * self.correct_num / self.all_num
|
||||||
|
self.reset()
|
||||||
|
return {'acc': acc}
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.correct_num = 0
|
||||||
|
self.all_num = 0
|
|
@ -69,7 +69,7 @@ class BaseModel(nn.Layer):
|
||||||
|
|
||||||
self.return_all_feats = config.get("return_all_feats", False)
|
self.return_all_feats = config.get("return_all_feats", False)
|
||||||
|
|
||||||
def forward(self, x, data=None):
|
def forward(self, x, data=None, mode='Train'):
|
||||||
y = dict()
|
y = dict()
|
||||||
if self.use_transform:
|
if self.use_transform:
|
||||||
x = self.transform(x)
|
x = self.transform(x)
|
||||||
|
@ -81,7 +81,10 @@ class BaseModel(nn.Layer):
|
||||||
if data is None:
|
if data is None:
|
||||||
x = self.head(x)
|
x = self.head(x)
|
||||||
else:
|
else:
|
||||||
x = self.head(x, data)
|
if mode == 'Eval' or mode == 'Test':
|
||||||
|
x = self.head(x, targets=data, mode=mode)
|
||||||
|
else:
|
||||||
|
x = self.head(x, targets=data)
|
||||||
y["head_out"] = x
|
y["head_out"] = x
|
||||||
if self.return_all_feats:
|
if self.return_all_feats:
|
||||||
return y
|
return y
|
||||||
|
|
|
@ -29,6 +29,10 @@ def build_backbone(config, model_type):
|
||||||
elif model_type == 'e2e':
|
elif model_type == 'e2e':
|
||||||
from .e2e_resnet_vd_pg import ResNet
|
from .e2e_resnet_vd_pg import ResNet
|
||||||
support_dict = ['ResNet']
|
support_dict = ['ResNet']
|
||||||
|
elif model_type == "table":
|
||||||
|
from .table_resnet_vd import ResNet
|
||||||
|
from .table_mobilenet_v3 import MobileNetV3
|
||||||
|
support_dict = ['ResNet', 'MobileNetV3']
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,287 @@
|
||||||
|
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle import ParamAttr
|
||||||
|
|
||||||
|
__all__ = ['MobileNetV3']
|
||||||
|
|
||||||
|
|
||||||
|
def make_divisible(v, divisor=8, min_value=None):
|
||||||
|
if min_value is None:
|
||||||
|
min_value = divisor
|
||||||
|
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||||
|
if new_v < 0.9 * v:
|
||||||
|
new_v += divisor
|
||||||
|
return new_v
|
||||||
|
|
||||||
|
|
||||||
|
class MobileNetV3(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels=3,
|
||||||
|
model_name='large',
|
||||||
|
scale=0.5,
|
||||||
|
disable_se=False,
|
||||||
|
**kwargs):
|
||||||
|
"""
|
||||||
|
the MobilenetV3 backbone network for detection module.
|
||||||
|
Args:
|
||||||
|
params(dict): the super parameters for build network
|
||||||
|
"""
|
||||||
|
super(MobileNetV3, self).__init__()
|
||||||
|
|
||||||
|
self.disable_se = disable_se
|
||||||
|
|
||||||
|
if model_name == "large":
|
||||||
|
cfg = [
|
||||||
|
# k, exp, c, se, nl, s,
|
||||||
|
[3, 16, 16, False, 'relu', 1],
|
||||||
|
[3, 64, 24, False, 'relu', 2],
|
||||||
|
[3, 72, 24, False, 'relu', 1],
|
||||||
|
[5, 72, 40, True, 'relu', 2],
|
||||||
|
[5, 120, 40, True, 'relu', 1],
|
||||||
|
[5, 120, 40, True, 'relu', 1],
|
||||||
|
[3, 240, 80, False, 'hardswish', 2],
|
||||||
|
[3, 200, 80, False, 'hardswish', 1],
|
||||||
|
[3, 184, 80, False, 'hardswish', 1],
|
||||||
|
[3, 184, 80, False, 'hardswish', 1],
|
||||||
|
[3, 480, 112, True, 'hardswish', 1],
|
||||||
|
[3, 672, 112, True, 'hardswish', 1],
|
||||||
|
[5, 672, 160, True, 'hardswish', 2],
|
||||||
|
[5, 960, 160, True, 'hardswish', 1],
|
||||||
|
[5, 960, 160, True, 'hardswish', 1],
|
||||||
|
]
|
||||||
|
cls_ch_squeeze = 960
|
||||||
|
elif model_name == "small":
|
||||||
|
cfg = [
|
||||||
|
# k, exp, c, se, nl, s,
|
||||||
|
[3, 16, 16, True, 'relu', 2],
|
||||||
|
[3, 72, 24, False, 'relu', 2],
|
||||||
|
[3, 88, 24, False, 'relu', 1],
|
||||||
|
[5, 96, 40, True, 'hardswish', 2],
|
||||||
|
[5, 240, 40, True, 'hardswish', 1],
|
||||||
|
[5, 240, 40, True, 'hardswish', 1],
|
||||||
|
[5, 120, 48, True, 'hardswish', 1],
|
||||||
|
[5, 144, 48, True, 'hardswish', 1],
|
||||||
|
[5, 288, 96, True, 'hardswish', 2],
|
||||||
|
[5, 576, 96, True, 'hardswish', 1],
|
||||||
|
[5, 576, 96, True, 'hardswish', 1],
|
||||||
|
]
|
||||||
|
cls_ch_squeeze = 576
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("mode[" + model_name +
|
||||||
|
"_model] is not implemented!")
|
||||||
|
|
||||||
|
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
|
||||||
|
assert scale in supported_scale, \
|
||||||
|
"supported scale are {} but input scale is {}".format(supported_scale, scale)
|
||||||
|
inplanes = 16
|
||||||
|
# conv1
|
||||||
|
self.conv = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=make_divisible(inplanes * scale),
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
padding=1,
|
||||||
|
groups=1,
|
||||||
|
if_act=True,
|
||||||
|
act='hardswish',
|
||||||
|
name='conv1')
|
||||||
|
|
||||||
|
self.stages = []
|
||||||
|
self.out_channels = []
|
||||||
|
block_list = []
|
||||||
|
i = 0
|
||||||
|
inplanes = make_divisible(inplanes * scale)
|
||||||
|
for (k, exp, c, se, nl, s) in cfg:
|
||||||
|
se = se and not self.disable_se
|
||||||
|
start_idx = 2 if model_name == 'large' else 0
|
||||||
|
if s == 2 and i > start_idx:
|
||||||
|
self.out_channels.append(inplanes)
|
||||||
|
self.stages.append(nn.Sequential(*block_list))
|
||||||
|
block_list = []
|
||||||
|
block_list.append(
|
||||||
|
ResidualUnit(
|
||||||
|
in_channels=inplanes,
|
||||||
|
mid_channels=make_divisible(scale * exp),
|
||||||
|
out_channels=make_divisible(scale * c),
|
||||||
|
kernel_size=k,
|
||||||
|
stride=s,
|
||||||
|
use_se=se,
|
||||||
|
act=nl,
|
||||||
|
name="conv" + str(i + 2)))
|
||||||
|
inplanes = make_divisible(scale * c)
|
||||||
|
i += 1
|
||||||
|
block_list.append(
|
||||||
|
ConvBNLayer(
|
||||||
|
in_channels=inplanes,
|
||||||
|
out_channels=make_divisible(scale * cls_ch_squeeze),
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
groups=1,
|
||||||
|
if_act=True,
|
||||||
|
act='hardswish',
|
||||||
|
name='conv_last'))
|
||||||
|
self.stages.append(nn.Sequential(*block_list))
|
||||||
|
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
|
||||||
|
for i, stage in enumerate(self.stages):
|
||||||
|
self.add_sublayer(sublayer=stage, name="stage{}".format(i))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
out_list = []
|
||||||
|
for stage in self.stages:
|
||||||
|
x = stage(x)
|
||||||
|
out_list.append(x)
|
||||||
|
return out_list
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBNLayer(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
padding,
|
||||||
|
groups=1,
|
||||||
|
if_act=True,
|
||||||
|
act=None,
|
||||||
|
name=None):
|
||||||
|
super(ConvBNLayer, self).__init__()
|
||||||
|
self.if_act = if_act
|
||||||
|
self.act = act
|
||||||
|
self.conv = nn.Conv2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=padding,
|
||||||
|
groups=groups,
|
||||||
|
weight_attr=ParamAttr(name=name + '_weights'),
|
||||||
|
bias_attr=False)
|
||||||
|
|
||||||
|
self.bn = nn.BatchNorm(
|
||||||
|
num_channels=out_channels,
|
||||||
|
act=None,
|
||||||
|
param_attr=ParamAttr(name=name + "_bn_scale"),
|
||||||
|
bias_attr=ParamAttr(name=name + "_bn_offset"),
|
||||||
|
moving_mean_name=name + "_bn_mean",
|
||||||
|
moving_variance_name=name + "_bn_variance")
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.conv(x)
|
||||||
|
x = self.bn(x)
|
||||||
|
if self.if_act:
|
||||||
|
if self.act == "relu":
|
||||||
|
x = F.relu(x)
|
||||||
|
elif self.act == "hardswish":
|
||||||
|
x = F.hardswish(x)
|
||||||
|
else:
|
||||||
|
print("The activation function({}) is selected incorrectly.".
|
||||||
|
format(self.act))
|
||||||
|
exit()
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class ResidualUnit(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
mid_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride,
|
||||||
|
use_se,
|
||||||
|
act=None,
|
||||||
|
name=''):
|
||||||
|
super(ResidualUnit, self).__init__()
|
||||||
|
self.if_shortcut = stride == 1 and in_channels == out_channels
|
||||||
|
self.if_se = use_se
|
||||||
|
|
||||||
|
self.expand_conv = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=mid_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
if_act=True,
|
||||||
|
act=act,
|
||||||
|
name=name + "_expand")
|
||||||
|
self.bottleneck_conv = ConvBNLayer(
|
||||||
|
in_channels=mid_channels,
|
||||||
|
out_channels=mid_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=int((kernel_size - 1) // 2),
|
||||||
|
groups=mid_channels,
|
||||||
|
if_act=True,
|
||||||
|
act=act,
|
||||||
|
name=name + "_depthwise")
|
||||||
|
if self.if_se:
|
||||||
|
self.mid_se = SEModule(mid_channels, name=name + "_se")
|
||||||
|
self.linear_conv = ConvBNLayer(
|
||||||
|
in_channels=mid_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
if_act=False,
|
||||||
|
act=None,
|
||||||
|
name=name + "_linear")
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
x = self.expand_conv(inputs)
|
||||||
|
x = self.bottleneck_conv(x)
|
||||||
|
if self.if_se:
|
||||||
|
x = self.mid_se(x)
|
||||||
|
x = self.linear_conv(x)
|
||||||
|
if self.if_shortcut:
|
||||||
|
x = paddle.add(inputs, x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class SEModule(nn.Layer):
|
||||||
|
def __init__(self, in_channels, reduction=4, name=""):
|
||||||
|
super(SEModule, self).__init__()
|
||||||
|
self.avg_pool = nn.AdaptiveAvgPool2D(1)
|
||||||
|
self.conv1 = nn.Conv2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=in_channels // reduction,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
weight_attr=ParamAttr(name=name + "_1_weights"),
|
||||||
|
bias_attr=ParamAttr(name=name + "_1_offset"))
|
||||||
|
self.conv2 = nn.Conv2D(
|
||||||
|
in_channels=in_channels // reduction,
|
||||||
|
out_channels=in_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
padding=0,
|
||||||
|
weight_attr=ParamAttr(name + "_2_weights"),
|
||||||
|
bias_attr=ParamAttr(name=name + "_2_offset"))
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
outputs = self.avg_pool(inputs)
|
||||||
|
outputs = self.conv1(outputs)
|
||||||
|
outputs = F.relu(outputs)
|
||||||
|
outputs = self.conv2(outputs)
|
||||||
|
outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
|
||||||
|
return inputs * outputs
|
|
@ -0,0 +1,280 @@
|
||||||
|
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import ParamAttr
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
|
||||||
|
__all__ = ["ResNet"]
|
||||||
|
|
||||||
|
|
||||||
|
class ConvBNLayer(nn.Layer):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
kernel_size,
|
||||||
|
stride=1,
|
||||||
|
groups=1,
|
||||||
|
is_vd_mode=False,
|
||||||
|
act=None,
|
||||||
|
name=None, ):
|
||||||
|
super(ConvBNLayer, self).__init__()
|
||||||
|
|
||||||
|
self.is_vd_mode = is_vd_mode
|
||||||
|
self._pool2d_avg = nn.AvgPool2D(
|
||||||
|
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||||
|
self._conv = nn.Conv2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=kernel_size,
|
||||||
|
stride=stride,
|
||||||
|
padding=(kernel_size - 1) // 2,
|
||||||
|
groups=groups,
|
||||||
|
weight_attr=ParamAttr(name=name + "_weights"),
|
||||||
|
bias_attr=False)
|
||||||
|
if name == "conv1":
|
||||||
|
bn_name = "bn_" + name
|
||||||
|
else:
|
||||||
|
bn_name = "bn" + name[3:]
|
||||||
|
self._batch_norm = nn.BatchNorm(
|
||||||
|
out_channels,
|
||||||
|
act=act,
|
||||||
|
param_attr=ParamAttr(name=bn_name + '_scale'),
|
||||||
|
bias_attr=ParamAttr(bn_name + '_offset'),
|
||||||
|
moving_mean_name=bn_name + '_mean',
|
||||||
|
moving_variance_name=bn_name + '_variance')
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
if self.is_vd_mode:
|
||||||
|
inputs = self._pool2d_avg(inputs)
|
||||||
|
y = self._conv(inputs)
|
||||||
|
y = self._batch_norm(y)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class BottleneckBlock(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
stride,
|
||||||
|
shortcut=True,
|
||||||
|
if_first=False,
|
||||||
|
name=None):
|
||||||
|
super(BottleneckBlock, self).__init__()
|
||||||
|
|
||||||
|
self.conv0 = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
act='relu',
|
||||||
|
name=name + "_branch2a")
|
||||||
|
self.conv1 = ConvBNLayer(
|
||||||
|
in_channels=out_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=stride,
|
||||||
|
act='relu',
|
||||||
|
name=name + "_branch2b")
|
||||||
|
self.conv2 = ConvBNLayer(
|
||||||
|
in_channels=out_channels,
|
||||||
|
out_channels=out_channels * 4,
|
||||||
|
kernel_size=1,
|
||||||
|
act=None,
|
||||||
|
name=name + "_branch2c")
|
||||||
|
|
||||||
|
if not shortcut:
|
||||||
|
self.short = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels * 4,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
is_vd_mode=False if if_first else True,
|
||||||
|
name=name + "_branch1")
|
||||||
|
|
||||||
|
self.shortcut = shortcut
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
y = self.conv0(inputs)
|
||||||
|
conv1 = self.conv1(y)
|
||||||
|
conv2 = self.conv2(conv1)
|
||||||
|
|
||||||
|
if self.shortcut:
|
||||||
|
short = inputs
|
||||||
|
else:
|
||||||
|
short = self.short(inputs)
|
||||||
|
y = paddle.add(x=short, y=conv2)
|
||||||
|
y = F.relu(y)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class BasicBlock(nn.Layer):
|
||||||
|
def __init__(self,
|
||||||
|
in_channels,
|
||||||
|
out_channels,
|
||||||
|
stride,
|
||||||
|
shortcut=True,
|
||||||
|
if_first=False,
|
||||||
|
name=None):
|
||||||
|
super(BasicBlock, self).__init__()
|
||||||
|
self.stride = stride
|
||||||
|
self.conv0 = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=stride,
|
||||||
|
act='relu',
|
||||||
|
name=name + "_branch2a")
|
||||||
|
self.conv1 = ConvBNLayer(
|
||||||
|
in_channels=out_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=3,
|
||||||
|
act=None,
|
||||||
|
name=name + "_branch2b")
|
||||||
|
|
||||||
|
if not shortcut:
|
||||||
|
self.short = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride=1,
|
||||||
|
is_vd_mode=False if if_first else True,
|
||||||
|
name=name + "_branch1")
|
||||||
|
|
||||||
|
self.shortcut = shortcut
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
y = self.conv0(inputs)
|
||||||
|
conv1 = self.conv1(y)
|
||||||
|
|
||||||
|
if self.shortcut:
|
||||||
|
short = inputs
|
||||||
|
else:
|
||||||
|
short = self.short(inputs)
|
||||||
|
y = paddle.add(x=short, y=conv1)
|
||||||
|
y = F.relu(y)
|
||||||
|
return y
|
||||||
|
|
||||||
|
|
||||||
|
class ResNet(nn.Layer):
|
||||||
|
def __init__(self, in_channels=3, layers=50, **kwargs):
|
||||||
|
super(ResNet, self).__init__()
|
||||||
|
|
||||||
|
self.layers = layers
|
||||||
|
supported_layers = [18, 34, 50, 101, 152, 200]
|
||||||
|
assert layers in supported_layers, \
|
||||||
|
"supported layers are {} but input layer is {}".format(
|
||||||
|
supported_layers, layers)
|
||||||
|
|
||||||
|
if layers == 18:
|
||||||
|
depth = [2, 2, 2, 2]
|
||||||
|
elif layers == 34 or layers == 50:
|
||||||
|
depth = [3, 4, 6, 3]
|
||||||
|
elif layers == 101:
|
||||||
|
depth = [3, 4, 23, 3]
|
||||||
|
elif layers == 152:
|
||||||
|
depth = [3, 8, 36, 3]
|
||||||
|
elif layers == 200:
|
||||||
|
depth = [3, 12, 48, 3]
|
||||||
|
num_channels = [64, 256, 512,
|
||||||
|
1024] if layers >= 50 else [64, 64, 128, 256]
|
||||||
|
num_filters = [64, 128, 256, 512]
|
||||||
|
|
||||||
|
self.conv1_1 = ConvBNLayer(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=32,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=2,
|
||||||
|
act='relu',
|
||||||
|
name="conv1_1")
|
||||||
|
self.conv1_2 = ConvBNLayer(
|
||||||
|
in_channels=32,
|
||||||
|
out_channels=32,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
act='relu',
|
||||||
|
name="conv1_2")
|
||||||
|
self.conv1_3 = ConvBNLayer(
|
||||||
|
in_channels=32,
|
||||||
|
out_channels=64,
|
||||||
|
kernel_size=3,
|
||||||
|
stride=1,
|
||||||
|
act='relu',
|
||||||
|
name="conv1_3")
|
||||||
|
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
|
||||||
|
|
||||||
|
self.stages = []
|
||||||
|
self.out_channels = []
|
||||||
|
if layers >= 50:
|
||||||
|
for block in range(len(depth)):
|
||||||
|
block_list = []
|
||||||
|
shortcut = False
|
||||||
|
for i in range(depth[block]):
|
||||||
|
if layers in [101, 152] and block == 2:
|
||||||
|
if i == 0:
|
||||||
|
conv_name = "res" + str(block + 2) + "a"
|
||||||
|
else:
|
||||||
|
conv_name = "res" + str(block + 2) + "b" + str(i)
|
||||||
|
else:
|
||||||
|
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||||
|
bottleneck_block = self.add_sublayer(
|
||||||
|
'bb_%d_%d' % (block, i),
|
||||||
|
BottleneckBlock(
|
||||||
|
in_channels=num_channels[block]
|
||||||
|
if i == 0 else num_filters[block] * 4,
|
||||||
|
out_channels=num_filters[block],
|
||||||
|
stride=2 if i == 0 and block != 0 else 1,
|
||||||
|
shortcut=shortcut,
|
||||||
|
if_first=block == i == 0,
|
||||||
|
name=conv_name))
|
||||||
|
shortcut = True
|
||||||
|
block_list.append(bottleneck_block)
|
||||||
|
self.out_channels.append(num_filters[block] * 4)
|
||||||
|
self.stages.append(nn.Sequential(*block_list))
|
||||||
|
else:
|
||||||
|
for block in range(len(depth)):
|
||||||
|
block_list = []
|
||||||
|
shortcut = False
|
||||||
|
for i in range(depth[block]):
|
||||||
|
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||||
|
basic_block = self.add_sublayer(
|
||||||
|
'bb_%d_%d' % (block, i),
|
||||||
|
BasicBlock(
|
||||||
|
in_channels=num_channels[block]
|
||||||
|
if i == 0 else num_filters[block],
|
||||||
|
out_channels=num_filters[block],
|
||||||
|
stride=2 if i == 0 and block != 0 else 1,
|
||||||
|
shortcut=shortcut,
|
||||||
|
if_first=block == i == 0,
|
||||||
|
name=conv_name))
|
||||||
|
shortcut = True
|
||||||
|
block_list.append(basic_block)
|
||||||
|
self.out_channels.append(num_filters[block])
|
||||||
|
self.stages.append(nn.Sequential(*block_list))
|
||||||
|
|
||||||
|
def forward(self, inputs):
|
||||||
|
y = self.conv1_1(inputs)
|
||||||
|
y = self.conv1_2(y)
|
||||||
|
y = self.conv1_3(y)
|
||||||
|
y = self.pool2d_max(y)
|
||||||
|
out = []
|
||||||
|
for block in self.stages:
|
||||||
|
y = block(y)
|
||||||
|
out.append(y)
|
||||||
|
return out
|
|
@ -31,8 +31,10 @@ def build_head(config):
|
||||||
from .cls_head import ClsHead
|
from .cls_head import ClsHead
|
||||||
support_dict = [
|
support_dict = [
|
||||||
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
||||||
'SRNHead', 'PGHead']
|
'SRNHead', 'PGHead', 'TableAttentionHead']
|
||||||
|
|
||||||
|
#table head
|
||||||
|
from .table_att_head import TableAttentionHead
|
||||||
|
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
assert module_name in support_dict, Exception('head only support {}'.format(
|
assert module_name in support_dict, Exception('head only support {}'.format(
|
||||||
|
|
|
@ -0,0 +1,240 @@
|
||||||
|
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
import paddle.nn as nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
class TableAttentionHead(nn.Layer):
|
||||||
|
def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
|
||||||
|
super(TableAttentionHead, self).__init__()
|
||||||
|
self.input_size = in_channels[-1]
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.char_num = 280
|
||||||
|
self.elem_num = 30
|
||||||
|
|
||||||
|
self.structure_attention_cell = AttentionGRUCell(
|
||||||
|
self.input_size, hidden_size, self.elem_num, use_gru=False)
|
||||||
|
self.structure_generator = nn.Linear(hidden_size, self.elem_num)
|
||||||
|
self.loc_type = loc_type
|
||||||
|
self.in_max_len = in_max_len
|
||||||
|
|
||||||
|
if self.loc_type == 1:
|
||||||
|
self.loc_generator = nn.Linear(hidden_size, 4)
|
||||||
|
else:
|
||||||
|
if self.in_max_len == 640:
|
||||||
|
self.loc_fea_trans = nn.Linear(400, 801)
|
||||||
|
elif self.in_max_len == 800:
|
||||||
|
self.loc_fea_trans = nn.Linear(625, 801)
|
||||||
|
else:
|
||||||
|
self.loc_fea_trans = nn.Linear(256, 801)
|
||||||
|
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
|
||||||
|
|
||||||
|
def _char_to_onehot(self, input_char, onehot_dim):
|
||||||
|
input_ont_hot = F.one_hot(input_char, onehot_dim)
|
||||||
|
return input_ont_hot
|
||||||
|
|
||||||
|
def forward(self, inputs, targets=None, mode='Train'):
|
||||||
|
# if and else branch are both needed when you want to assign a variable
|
||||||
|
# if you modify the var in just one branch, then the modification will not work.
|
||||||
|
fea = inputs[-1]
|
||||||
|
if len(fea.shape) == 3:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
last_shape = int(np.prod(fea.shape[2:])) # gry added
|
||||||
|
fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
|
||||||
|
fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
|
||||||
|
batch_size = fea.shape[0]
|
||||||
|
#sp_tokens = targets[2].numpy()
|
||||||
|
#char_beg_idx, char_end_idx = sp_tokens[0, 0:2]
|
||||||
|
#elem_beg_idx, elem_end_idx = sp_tokens[0, 2:4]
|
||||||
|
#elem_char_idx1, elem_char_idx2 = sp_tokens[0, 4:6]
|
||||||
|
#max_text_length, max_elem_length, max_cell_num = sp_tokens[0, 6:9]
|
||||||
|
max_text_length, max_elem_length, max_cell_num = 100, 800, 500
|
||||||
|
|
||||||
|
hidden = paddle.zeros((batch_size, self.hidden_size))
|
||||||
|
output_hiddens = []
|
||||||
|
if mode == 'Train' and targets is not None:
|
||||||
|
structure = targets[0]
|
||||||
|
for i in range(max_elem_length+1):
|
||||||
|
elem_onehots = self._char_to_onehot(
|
||||||
|
structure[:, i], onehot_dim=self.elem_num)
|
||||||
|
(outputs, hidden), alpha = self.structure_attention_cell(
|
||||||
|
hidden, fea, elem_onehots)
|
||||||
|
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
|
||||||
|
output = paddle.concat(output_hiddens, axis=1)
|
||||||
|
structure_probs = self.structure_generator(output)
|
||||||
|
if self.loc_type == 1:
|
||||||
|
loc_preds = self.loc_generator(output)
|
||||||
|
loc_preds = F.sigmoid(loc_preds)
|
||||||
|
else:
|
||||||
|
loc_fea = fea.transpose([0, 2, 1])
|
||||||
|
loc_fea = self.loc_fea_trans(loc_fea)
|
||||||
|
loc_fea = loc_fea.transpose([0, 2, 1])
|
||||||
|
loc_concat = paddle.concat([output, loc_fea], axis=2)
|
||||||
|
loc_preds = self.loc_generator(loc_concat)
|
||||||
|
loc_preds = F.sigmoid(loc_preds)
|
||||||
|
else:
|
||||||
|
temp_elem = paddle.zeros(shape=[batch_size], dtype="int32")
|
||||||
|
structure_probs = None
|
||||||
|
loc_preds = None
|
||||||
|
elem_onehots = None
|
||||||
|
outputs = None
|
||||||
|
alpha = None
|
||||||
|
max_elem_length = paddle.to_tensor(max_elem_length)
|
||||||
|
i = 0
|
||||||
|
while i < max_elem_length+1:
|
||||||
|
elem_onehots = self._char_to_onehot(
|
||||||
|
temp_elem, onehot_dim=self.elem_num)
|
||||||
|
(outputs, hidden), alpha = self.structure_attention_cell(
|
||||||
|
hidden, fea, elem_onehots)
|
||||||
|
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
|
||||||
|
structure_probs_step = self.structure_generator(outputs)
|
||||||
|
temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
output = paddle.concat(output_hiddens, axis=1)
|
||||||
|
structure_probs = self.structure_generator(output)
|
||||||
|
structure_probs = F.softmax(structure_probs)
|
||||||
|
if self.loc_type == 1:
|
||||||
|
loc_preds = self.loc_generator(output)
|
||||||
|
loc_preds = F.sigmoid(loc_preds)
|
||||||
|
else:
|
||||||
|
loc_fea = fea.transpose([0, 2, 1])
|
||||||
|
loc_fea = self.loc_fea_trans(loc_fea)
|
||||||
|
loc_fea = loc_fea.transpose([0, 2, 1])
|
||||||
|
loc_concat = paddle.concat([output, loc_fea], axis=2)
|
||||||
|
loc_preds = self.loc_generator(loc_concat)
|
||||||
|
loc_preds = F.sigmoid(loc_preds)
|
||||||
|
return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
|
||||||
|
|
||||||
|
class AttentionGRUCell(nn.Layer):
|
||||||
|
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
|
||||||
|
super(AttentionGRUCell, self).__init__()
|
||||||
|
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
|
||||||
|
self.h2h = nn.Linear(hidden_size, hidden_size)
|
||||||
|
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
|
||||||
|
self.rnn = nn.GRUCell(
|
||||||
|
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
|
||||||
|
def forward(self, prev_hidden, batch_H, char_onehots):
|
||||||
|
batch_H_proj = self.i2h(batch_H)
|
||||||
|
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
|
||||||
|
res = paddle.add(batch_H_proj, prev_hidden_proj)
|
||||||
|
res = paddle.tanh(res)
|
||||||
|
e = self.score(res)
|
||||||
|
alpha = F.softmax(e, axis=1)
|
||||||
|
alpha = paddle.transpose(alpha, [0, 2, 1])
|
||||||
|
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
|
||||||
|
concat_context = paddle.concat([context, char_onehots], 1)
|
||||||
|
cur_hidden = self.rnn(concat_context, prev_hidden)
|
||||||
|
return cur_hidden, alpha
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionLSTM(nn.Layer):
|
||||||
|
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
|
||||||
|
super(AttentionLSTM, self).__init__()
|
||||||
|
self.input_size = in_channels
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
self.num_classes = out_channels
|
||||||
|
|
||||||
|
self.attention_cell = AttentionLSTMCell(
|
||||||
|
in_channels, hidden_size, out_channels, use_gru=False)
|
||||||
|
self.generator = nn.Linear(hidden_size, out_channels)
|
||||||
|
|
||||||
|
def _char_to_onehot(self, input_char, onehot_dim):
|
||||||
|
input_ont_hot = F.one_hot(input_char, onehot_dim)
|
||||||
|
return input_ont_hot
|
||||||
|
|
||||||
|
def forward(self, inputs, targets=None, batch_max_length=25):
|
||||||
|
batch_size = inputs.shape[0]
|
||||||
|
num_steps = batch_max_length
|
||||||
|
|
||||||
|
hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
|
||||||
|
(batch_size, self.hidden_size)))
|
||||||
|
output_hiddens = []
|
||||||
|
|
||||||
|
if targets is not None:
|
||||||
|
for i in range(num_steps):
|
||||||
|
# one-hot vectors for a i-th char
|
||||||
|
char_onehots = self._char_to_onehot(
|
||||||
|
targets[:, i], onehot_dim=self.num_classes)
|
||||||
|
hidden, alpha = self.attention_cell(hidden, inputs,
|
||||||
|
char_onehots)
|
||||||
|
|
||||||
|
hidden = (hidden[1][0], hidden[1][1])
|
||||||
|
output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
|
||||||
|
output = paddle.concat(output_hiddens, axis=1)
|
||||||
|
probs = self.generator(output)
|
||||||
|
|
||||||
|
else:
|
||||||
|
targets = paddle.zeros(shape=[batch_size], dtype="int32")
|
||||||
|
probs = None
|
||||||
|
|
||||||
|
for i in range(num_steps):
|
||||||
|
char_onehots = self._char_to_onehot(
|
||||||
|
targets, onehot_dim=self.num_classes)
|
||||||
|
hidden, alpha = self.attention_cell(hidden, inputs,
|
||||||
|
char_onehots)
|
||||||
|
probs_step = self.generator(hidden[0])
|
||||||
|
hidden = (hidden[1][0], hidden[1][1])
|
||||||
|
if probs is None:
|
||||||
|
probs = paddle.unsqueeze(probs_step, axis=1)
|
||||||
|
else:
|
||||||
|
probs = paddle.concat(
|
||||||
|
[probs, paddle.unsqueeze(
|
||||||
|
probs_step, axis=1)], axis=1)
|
||||||
|
|
||||||
|
next_input = probs_step.argmax(axis=1)
|
||||||
|
|
||||||
|
targets = next_input
|
||||||
|
|
||||||
|
return probs
|
||||||
|
|
||||||
|
|
||||||
|
class AttentionLSTMCell(nn.Layer):
|
||||||
|
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
|
||||||
|
super(AttentionLSTMCell, self).__init__()
|
||||||
|
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
|
||||||
|
self.h2h = nn.Linear(hidden_size, hidden_size)
|
||||||
|
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
|
||||||
|
if not use_gru:
|
||||||
|
self.rnn = nn.LSTMCell(
|
||||||
|
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||||
|
else:
|
||||||
|
self.rnn = nn.GRUCell(
|
||||||
|
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||||
|
|
||||||
|
self.hidden_size = hidden_size
|
||||||
|
|
||||||
|
def forward(self, prev_hidden, batch_H, char_onehots):
|
||||||
|
batch_H_proj = self.i2h(batch_H)
|
||||||
|
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
|
||||||
|
res = paddle.add(batch_H_proj, prev_hidden_proj)
|
||||||
|
res = paddle.tanh(res)
|
||||||
|
e = self.score(res)
|
||||||
|
|
||||||
|
alpha = F.softmax(e, axis=1)
|
||||||
|
alpha = paddle.transpose(alpha, [0, 2, 1])
|
||||||
|
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
|
||||||
|
concat_context = paddle.concat([context, char_onehots], 1)
|
||||||
|
cur_hidden = self.rnn(concat_context, prev_hidden)
|
||||||
|
|
||||||
|
return cur_hidden, alpha
|
|
@ -21,7 +21,8 @@ def build_neck(config):
|
||||||
from .sast_fpn import SASTFPN
|
from .sast_fpn import SASTFPN
|
||||||
from .rnn import SequenceEncoder
|
from .rnn import SequenceEncoder
|
||||||
from .pg_fpn import PGFPN
|
from .pg_fpn import PGFPN
|
||||||
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
|
from .table_fpn import TableFPN
|
||||||
|
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
|
||||||
|
|
||||||
module_name = config.pop('name')
|
module_name = config.pop('name')
|
||||||
assert module_name in support_dict, Exception('neck only support {}'.format(
|
assert module_name in support_dict, Exception('neck only support {}'.format(
|
||||||
|
|
|
@ -0,0 +1,119 @@
|
||||||
|
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle import nn
|
||||||
|
import paddle.nn.functional as F
|
||||||
|
from paddle import ParamAttr
|
||||||
|
|
||||||
|
|
||||||
|
class TableFPN(nn.Layer):
|
||||||
|
def __init__(self, in_channels, out_channels, **kwargs):
|
||||||
|
super(TableFPN, self).__init__()
|
||||||
|
self.out_channels = 512
|
||||||
|
weight_attr = paddle.nn.initializer.KaimingUniform()
|
||||||
|
self.in2_conv = nn.Conv2D(
|
||||||
|
in_channels=in_channels[0],
|
||||||
|
out_channels=self.out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
weight_attr=ParamAttr(
|
||||||
|
name='conv2d_51.w_0', initializer=weight_attr),
|
||||||
|
bias_attr=False)
|
||||||
|
self.in3_conv = nn.Conv2D(
|
||||||
|
in_channels=in_channels[1],
|
||||||
|
out_channels=self.out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
stride = 1,
|
||||||
|
weight_attr=ParamAttr(
|
||||||
|
name='conv2d_50.w_0', initializer=weight_attr),
|
||||||
|
bias_attr=False)
|
||||||
|
self.in4_conv = nn.Conv2D(
|
||||||
|
in_channels=in_channels[2],
|
||||||
|
out_channels=self.out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
weight_attr=ParamAttr(
|
||||||
|
name='conv2d_49.w_0', initializer=weight_attr),
|
||||||
|
bias_attr=False)
|
||||||
|
self.in5_conv = nn.Conv2D(
|
||||||
|
in_channels=in_channels[3],
|
||||||
|
out_channels=self.out_channels,
|
||||||
|
kernel_size=1,
|
||||||
|
weight_attr=ParamAttr(
|
||||||
|
name='conv2d_48.w_0', initializer=weight_attr),
|
||||||
|
bias_attr=False)
|
||||||
|
self.p5_conv = nn.Conv2D(
|
||||||
|
in_channels=self.out_channels,
|
||||||
|
out_channels=self.out_channels // 4,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
weight_attr=ParamAttr(
|
||||||
|
name='conv2d_52.w_0', initializer=weight_attr),
|
||||||
|
bias_attr=False)
|
||||||
|
self.p4_conv = nn.Conv2D(
|
||||||
|
in_channels=self.out_channels,
|
||||||
|
out_channels=self.out_channels // 4,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
weight_attr=ParamAttr(
|
||||||
|
name='conv2d_53.w_0', initializer=weight_attr),
|
||||||
|
bias_attr=False)
|
||||||
|
self.p3_conv = nn.Conv2D(
|
||||||
|
in_channels=self.out_channels,
|
||||||
|
out_channels=self.out_channels // 4,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
weight_attr=ParamAttr(
|
||||||
|
name='conv2d_54.w_0', initializer=weight_attr),
|
||||||
|
bias_attr=False)
|
||||||
|
self.p2_conv = nn.Conv2D(
|
||||||
|
in_channels=self.out_channels,
|
||||||
|
out_channels=self.out_channels // 4,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
weight_attr=ParamAttr(
|
||||||
|
name='conv2d_55.w_0', initializer=weight_attr),
|
||||||
|
bias_attr=False)
|
||||||
|
self.fuse_conv = nn.Conv2D(
|
||||||
|
in_channels=self.out_channels * 4,
|
||||||
|
out_channels=512,
|
||||||
|
kernel_size=3,
|
||||||
|
padding=1,
|
||||||
|
weight_attr=ParamAttr(
|
||||||
|
name='conv2d_fuse.w_0', initializer=weight_attr), bias_attr=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
c2, c3, c4, c5 = x
|
||||||
|
|
||||||
|
in5 = self.in5_conv(c5)
|
||||||
|
in4 = self.in4_conv(c4)
|
||||||
|
in3 = self.in3_conv(c3)
|
||||||
|
in2 = self.in2_conv(c2)
|
||||||
|
|
||||||
|
out4 = in4 + F.upsample(
|
||||||
|
in5, size=in4.shape[2:4], mode="nearest", align_mode=1) # 1/16
|
||||||
|
out3 = in3 + F.upsample(
|
||||||
|
out4, size=in3.shape[2:4], mode="nearest", align_mode=1) # 1/8
|
||||||
|
out2 = in2 + F.upsample(
|
||||||
|
out3, size=in2.shape[2:4], mode="nearest", align_mode=1) # 1/4
|
||||||
|
|
||||||
|
p4 = F.upsample(out4, size=in5.shape[2:4], mode="nearest", align_mode=1)
|
||||||
|
p3 = F.upsample(out3, size=in5.shape[2:4], mode="nearest", align_mode=1)
|
||||||
|
p2 = F.upsample(out2, size=in5.shape[2:4], mode="nearest", align_mode=1)
|
||||||
|
fuse = paddle.concat([in5, p4, p3, p2], axis=1)
|
||||||
|
fuse_conv = self.fuse_conv(fuse) * 0.005
|
||||||
|
return [c5 + fuse_conv]
|
|
@ -325,8 +325,14 @@ class TableLabelDecode(object):
|
||||||
""" """
|
""" """
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
max_text_length,
|
||||||
|
max_elem_length,
|
||||||
|
max_cell_num,
|
||||||
character_dict_path,
|
character_dict_path,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
self.max_text_length = max_text_length
|
||||||
|
self.max_elem_length = max_elem_length
|
||||||
|
self.max_cell_num = max_cell_num
|
||||||
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
|
list_character, list_elem = self.load_char_elem_dict(character_dict_path)
|
||||||
list_character = self.add_special_char(list_character)
|
list_character = self.add_special_char(list_character)
|
||||||
list_elem = self.add_special_char(list_elem)
|
list_elem = self.add_special_char(list_elem)
|
||||||
|
@ -363,6 +369,18 @@ class TableLabelDecode(object):
|
||||||
list_character = [self.beg_str] + list_character + [self.end_str]
|
list_character = [self.beg_str] + list_character + [self.end_str]
|
||||||
return list_character
|
return list_character
|
||||||
|
|
||||||
|
def get_sp_tokens(self):
|
||||||
|
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
|
||||||
|
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
|
||||||
|
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
|
||||||
|
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
|
||||||
|
elem_char_idx1 = self.dict_elem['<td>']
|
||||||
|
elem_char_idx2 = self.dict_elem['<td']
|
||||||
|
sp_tokens = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
|
||||||
|
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
|
||||||
|
self.max_elem_length, self.max_cell_num])
|
||||||
|
return sp_tokens
|
||||||
|
|
||||||
def __call__(self, preds):
|
def __call__(self, preds):
|
||||||
structure_probs = preds['structure_probs']
|
structure_probs = preds['structure_probs']
|
||||||
loc_preds = preds['loc_preds']
|
loc_preds = preds['loc_preds']
|
||||||
|
|
|
@ -48,6 +48,7 @@ def main():
|
||||||
getattr(post_process_class, 'character'))
|
getattr(post_process_class, 'character'))
|
||||||
model = build_model(config['Architecture'])
|
model = build_model(config['Architecture'])
|
||||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||||
|
model_type = config['Architecture']['model_type']
|
||||||
|
|
||||||
best_model_dict = init_model(config, model)
|
best_model_dict = init_model(config, model)
|
||||||
if len(best_model_dict):
|
if len(best_model_dict):
|
||||||
|
@ -60,7 +61,7 @@ def main():
|
||||||
|
|
||||||
# start eval
|
# start eval
|
||||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||||
eval_class, use_srn)
|
eval_class, model_type, use_srn)
|
||||||
logger.info('metric eval ***************')
|
logger.info('metric eval ***************')
|
||||||
for k, v in metric.items():
|
for k, v in metric.items():
|
||||||
logger.info('{}:{}'.format(k, v))
|
logger.info('{}:{}'.format(k, v))
|
||||||
|
|
|
@ -0,0 +1,109 @@
|
||||||
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
|
||||||
|
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
sys.path.append(__dir__)
|
||||||
|
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||||
|
|
||||||
|
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
|
||||||
|
|
||||||
|
import paddle
|
||||||
|
from paddle.jit import to_static
|
||||||
|
|
||||||
|
from ppocr.data import create_operators, transform
|
||||||
|
from ppocr.modeling.architectures import build_model
|
||||||
|
from ppocr.postprocess import build_post_process
|
||||||
|
from ppocr.utils.save_load import init_model
|
||||||
|
from ppocr.utils.utility import get_image_file_list
|
||||||
|
import tools.program as program
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
def main(config, device, logger, vdl_writer):
|
||||||
|
global_config = config['Global']
|
||||||
|
|
||||||
|
# build post process
|
||||||
|
post_process_class = build_post_process(config['PostProcess'],
|
||||||
|
global_config)
|
||||||
|
|
||||||
|
# build model
|
||||||
|
if hasattr(post_process_class, 'character'):
|
||||||
|
config['Architecture']["Head"]['out_channels'] = len(
|
||||||
|
getattr(post_process_class, 'character'))
|
||||||
|
|
||||||
|
model = build_model(config['Architecture'])
|
||||||
|
|
||||||
|
init_model(config, model, logger)
|
||||||
|
|
||||||
|
# create data ops
|
||||||
|
transforms = []
|
||||||
|
use_padding = False
|
||||||
|
for op in config['Eval']['dataset']['transforms']:
|
||||||
|
op_name = list(op)[0]
|
||||||
|
if 'Label' in op_name:
|
||||||
|
continue
|
||||||
|
if op_name == 'KeepKeys':
|
||||||
|
op[op_name]['keep_keys'] = ['image']
|
||||||
|
if op_name == "ResizeTableImage":
|
||||||
|
use_padding = True
|
||||||
|
padding_max_len = op['ResizeTableImage']['max_len']
|
||||||
|
transforms.append(op)
|
||||||
|
|
||||||
|
global_config['infer_mode'] = True
|
||||||
|
ops = create_operators(transforms, global_config)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
for file in get_image_file_list(config['Global']['infer_img']):
|
||||||
|
logger.info("infer_img: {}".format(file))
|
||||||
|
with open(file, 'rb') as f:
|
||||||
|
img = f.read()
|
||||||
|
data = {'image': img}
|
||||||
|
batch = transform(data, ops)
|
||||||
|
sp_tokens = post_process_class.get_sp_tokens()
|
||||||
|
targets = [[], [], paddle.to_tensor([sp_tokens])]
|
||||||
|
images = np.expand_dims(batch[0], axis=0)
|
||||||
|
images = paddle.to_tensor(images)
|
||||||
|
preds = model(images, data=targets, mode='Test')
|
||||||
|
post_result = post_process_class(preds)
|
||||||
|
res_html_code = post_result['res_html_code']
|
||||||
|
res_loc = post_result['res_loc']
|
||||||
|
img = cv2.imread(file)
|
||||||
|
imgh, imgw = img.shape[0:2]
|
||||||
|
res_loc_final = []
|
||||||
|
for rno in range(len(res_loc[0])):
|
||||||
|
x0, y0, x1, y1 = res_loc[0][rno]
|
||||||
|
left = max(int(imgw * x0), 0)
|
||||||
|
top = max(int(imgh * y0), 0)
|
||||||
|
right = min(int(imgw * x1), imgw - 1)
|
||||||
|
bottom = min(int(imgh * y1), imgh - 1)
|
||||||
|
cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2)
|
||||||
|
res_loc_final.append([left, top, right, bottom])
|
||||||
|
res_loc_str = json.dumps(res_loc_final)
|
||||||
|
logger.info("result: {}, {}".format(res_html_code, res_loc_final))
|
||||||
|
logger.info("success!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
config, device, logger, vdl_writer = program.preprocess()
|
||||||
|
main(config, device, logger, vdl_writer)
|
||||||
|
|
|
@ -186,7 +186,8 @@ def train(config,
|
||||||
model.train()
|
model.train()
|
||||||
|
|
||||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||||
|
model_type = config['Architecture']['model_type']
|
||||||
|
|
||||||
if 'start_epoch' in best_model_dict:
|
if 'start_epoch' in best_model_dict:
|
||||||
start_epoch = best_model_dict['start_epoch']
|
start_epoch = best_model_dict['start_epoch']
|
||||||
else:
|
else:
|
||||||
|
@ -211,6 +212,9 @@ def train(config,
|
||||||
others = batch[-4:]
|
others = batch[-4:]
|
||||||
preds = model(images, others)
|
preds = model(images, others)
|
||||||
model_average = True
|
model_average = True
|
||||||
|
elif model_type == "table":
|
||||||
|
others = batch[1:]
|
||||||
|
preds = model(images, others)
|
||||||
else:
|
else:
|
||||||
preds = model(images)
|
preds = model(images)
|
||||||
loss = loss_class(preds, batch)
|
loss = loss_class(preds, batch)
|
||||||
|
@ -232,8 +236,11 @@ def train(config,
|
||||||
|
|
||||||
if cal_metric_during_train: # only rec and cls need
|
if cal_metric_during_train: # only rec and cls need
|
||||||
batch = [item.numpy() for item in batch]
|
batch = [item.numpy() for item in batch]
|
||||||
post_result = post_process_class(preds, batch[1])
|
if model_type == 'table':
|
||||||
eval_class(post_result, batch)
|
eval_class(preds, batch)
|
||||||
|
else:
|
||||||
|
post_result = post_process_class(preds, batch[1])
|
||||||
|
eval_class(post_result, batch)
|
||||||
metric = eval_class.get_metric()
|
metric = eval_class.get_metric()
|
||||||
train_stats.update(metric)
|
train_stats.update(metric)
|
||||||
|
|
||||||
|
@ -337,7 +344,7 @@ def train(config,
|
||||||
|
|
||||||
|
|
||||||
def eval(model, valid_dataloader, post_process_class, eval_class,
|
def eval(model, valid_dataloader, post_process_class, eval_class,
|
||||||
use_srn=False):
|
model_type, use_srn=False):
|
||||||
model.eval()
|
model.eval()
|
||||||
with paddle.no_grad():
|
with paddle.no_grad():
|
||||||
total_frame = 0.0
|
total_frame = 0.0
|
||||||
|
@ -359,10 +366,13 @@ def eval(model, valid_dataloader, post_process_class, eval_class,
|
||||||
|
|
||||||
batch = [item.numpy() for item in batch]
|
batch = [item.numpy() for item in batch]
|
||||||
# Obtain usable results from post-processing methods
|
# Obtain usable results from post-processing methods
|
||||||
post_result = post_process_class(preds, batch[1])
|
|
||||||
total_time += time.time() - start
|
total_time += time.time() - start
|
||||||
# Evaluate the results of the current batch
|
# Evaluate the results of the current batch
|
||||||
eval_class(post_result, batch)
|
if model_type == 'table':
|
||||||
|
eval_class(preds, batch)
|
||||||
|
else:
|
||||||
|
post_result = post_process_class(preds, batch[1])
|
||||||
|
eval_class(post_result, batch)
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
total_frame += len(images)
|
total_frame += len(images)
|
||||||
# Get final metric,eg. acc or hmean
|
# Get final metric,eg. acc or hmean
|
||||||
|
@ -386,7 +396,7 @@ def preprocess(is_train=False):
|
||||||
alg = config['Architecture']['algorithm']
|
alg = config['Architecture']['algorithm']
|
||||||
assert alg in [
|
assert alg in [
|
||||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||||
'CLS', 'PGNet', 'Distillation'
|
'CLS', 'PGNet', 'Distillation', 'TableAttn'
|
||||||
]
|
]
|
||||||
|
|
||||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||||
|
|
Loading…
Reference in New Issue