refine
This commit is contained in:
parent
7c8b2c8d19
commit
16c247ac46
|
@ -1,13 +1,12 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 40
|
||||
epoch_num: 50
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 5
|
||||
save_model_dir: ./output/table_mv3/
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
save_epoch_step: 5
|
||||
# evaluation is run every 400 iterations after the 0th iteration
|
||||
eval_batch_step: [0, 400]
|
||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
|
@ -18,19 +17,20 @@ Global:
|
|||
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
|
||||
character_type: en
|
||||
max_text_length: 100
|
||||
max_elem_length: 800
|
||||
max_elem_length: 500
|
||||
max_cell_num: 500
|
||||
infer_mode: False
|
||||
process_total_num: 0
|
||||
process_cut_num: 0
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
clip_norm: 5.0
|
||||
lr:
|
||||
learning_rate: 0.0001
|
||||
learning_rate: 0.001
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0.00000
|
||||
|
@ -41,12 +41,12 @@ Architecture:
|
|||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 1.0
|
||||
model_name: large
|
||||
model_name: small
|
||||
disable_se: True
|
||||
Head:
|
||||
name: TableAttentionHead # AttentionHead
|
||||
hidden_size: 256 #
|
||||
name: TableAttentionHead
|
||||
hidden_size: 256
|
||||
l2_decay: 0.00001
|
||||
# loc_type: 1
|
||||
loc_type: 2
|
||||
|
||||
Loss:
|
||||
|
@ -86,7 +86,7 @@ Train:
|
|||
shuffle: True
|
||||
batch_size_per_card: 32
|
||||
drop_last: True
|
||||
num_workers: 4
|
||||
num_workers: 1
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
|
@ -113,4 +113,4 @@ Eval:
|
|||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 16
|
||||
num_workers: 4
|
||||
num_workers: 1
|
||||
|
|
|
@ -412,7 +412,6 @@ class TableLabelEncode(object):
|
|||
return None
|
||||
elem_num = len(structure)
|
||||
structure = [0] + structure + [len(self.dict_elem) - 1]
|
||||
# structure = [0] + structure + [0]
|
||||
structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
|
||||
structure = np.array(structure)
|
||||
data['structure'] = structure
|
||||
|
@ -443,8 +442,6 @@ class TableLabelEncode(object):
|
|||
if cand_span_idx < (self.max_elem_length + 2):
|
||||
if structure[cand_span_idx] in span_idx_list:
|
||||
structure_mask[cand_span_idx] = span_weight
|
||||
# structure_mask[td_idx] = self.span_weight
|
||||
# structure_mask[cand_span_idx] = self.span_weight
|
||||
|
||||
data['bbox_list'] = bbox_list
|
||||
data['bbox_list_mask'] = bbox_list_mask
|
||||
|
@ -458,23 +455,6 @@ class TableLabelEncode(object):
|
|||
self.max_elem_length, self.max_cell_num, elem_num])
|
||||
return data
|
||||
|
||||
########
|
||||
# for char decode
|
||||
# cell_list = []
|
||||
# for cell in cells:
|
||||
# char_list = cell['tokens']
|
||||
# cell = self.encode(char_list, 'char')
|
||||
# if cell is None:
|
||||
# return None
|
||||
# cell = [0] + cell + [len(self.dict_character) - 1]
|
||||
# cell = cell + [0] * (self.max_text_length + 2 - len(cell))
|
||||
# cell_list.append(cell)
|
||||
# cell_list_padding = np.zeros((self.max_cell_num, self.max_text_length + 2))
|
||||
# cell_list = np.array(cell_list)
|
||||
# cell_list_padding[0:cell_list.shape[0]] = cell_list
|
||||
# data['cells'] = cell_list_padding
|
||||
# return data
|
||||
|
||||
def encode(self, text, char_or_elem):
|
||||
"""convert text-label into text-index.
|
||||
"""
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -19,6 +19,7 @@ import json
|
|||
|
||||
from .imaug import transform, create_operators
|
||||
|
||||
|
||||
class PubTabDataSet(Dataset):
|
||||
def __init__(self, config, mode, logger, seed=None):
|
||||
super(PubTabDataSet, self).__init__()
|
||||
|
@ -58,23 +59,6 @@ class PubTabDataSet(Dataset):
|
|||
random.shuffle(self.data_lines)
|
||||
return
|
||||
|
||||
def load_hard_select_prob(self):
|
||||
label_path = "./pretrained_model/teds_score_exp5_st2_train.txt"
|
||||
img_select_prob = {}
|
||||
with open(label_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for lno in range(len(lines)):
|
||||
substr = lines[lno].decode('utf-8').strip("\n").split(" ")
|
||||
img_name = substr[0].strip(":")
|
||||
score = float(substr[1])
|
||||
if score <= 0.8:
|
||||
img_select_prob[img_name] = self.hard_prob[0]
|
||||
elif score <= 0.98:
|
||||
img_select_prob[img_name] = self.hard_prob[1]
|
||||
else:
|
||||
img_select_prob[img_name] = self.hard_prob[2]
|
||||
return img_select_prob
|
||||
|
||||
def __getitem__(self, idx):
|
||||
try:
|
||||
data_line = self.data_lines[idx]
|
||||
|
@ -93,8 +77,6 @@ class PubTabDataSet(Dataset):
|
|||
table_type = "simple"
|
||||
if 'colspan' in structure_str or 'rowspan' in structure_str:
|
||||
table_type = "complex"
|
||||
# if self.table_select_type != table_type:
|
||||
# select_flag = False
|
||||
if table_type == "complex":
|
||||
if self.table_select_prob < random.uniform(0, 1):
|
||||
select_flag = False
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
|
|
@ -21,13 +21,16 @@ import paddle.nn as nn
|
|||
import paddle.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TableAttentionHead(nn.Layer):
|
||||
def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
|
||||
super(TableAttentionHead, self).__init__()
|
||||
self.input_size = in_channels[-1]
|
||||
self.hidden_size = hidden_size
|
||||
self.char_num = 280
|
||||
self.elem_num = 30
|
||||
self.max_text_length = 100
|
||||
self.max_elem_length = 500
|
||||
self.max_cell_num = 500
|
||||
|
||||
self.structure_attention_cell = AttentionGRUCell(
|
||||
self.input_size, hidden_size, self.elem_num, use_gru=False)
|
||||
|
@ -39,11 +42,11 @@ class TableAttentionHead(nn.Layer):
|
|||
self.loc_generator = nn.Linear(hidden_size, 4)
|
||||
else:
|
||||
if self.in_max_len == 640:
|
||||
self.loc_fea_trans = nn.Linear(400, 801)
|
||||
self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
|
||||
elif self.in_max_len == 800:
|
||||
self.loc_fea_trans = nn.Linear(625, 801)
|
||||
self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
|
||||
else:
|
||||
self.loc_fea_trans = nn.Linear(256, 801)
|
||||
self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
|
||||
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
|
||||
|
||||
def _char_to_onehot(self, input_char, onehot_dim):
|
||||
|
@ -61,18 +64,12 @@ class TableAttentionHead(nn.Layer):
|
|||
fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
|
||||
fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
|
||||
batch_size = fea.shape[0]
|
||||
#sp_tokens = targets[2].numpy()
|
||||
#char_beg_idx, char_end_idx = sp_tokens[0, 0:2]
|
||||
#elem_beg_idx, elem_end_idx = sp_tokens[0, 2:4]
|
||||
#elem_char_idx1, elem_char_idx2 = sp_tokens[0, 4:6]
|
||||
#max_text_length, max_elem_length, max_cell_num = sp_tokens[0, 6:9]
|
||||
max_text_length, max_elem_length, max_cell_num = 100, 800, 500
|
||||
|
||||
hidden = paddle.zeros((batch_size, self.hidden_size))
|
||||
output_hiddens = []
|
||||
if mode == 'Train' and targets is not None:
|
||||
structure = targets[0]
|
||||
for i in range(max_elem_length+1):
|
||||
for i in range(self.max_elem_length+1):
|
||||
elem_onehots = self._char_to_onehot(
|
||||
structure[:, i], onehot_dim=self.elem_num)
|
||||
(outputs, hidden), alpha = self.structure_attention_cell(
|
||||
|
@ -97,7 +94,7 @@ class TableAttentionHead(nn.Layer):
|
|||
elem_onehots = None
|
||||
outputs = None
|
||||
alpha = None
|
||||
max_elem_length = paddle.to_tensor(max_elem_length)
|
||||
max_elem_length = paddle.to_tensor(self.max_elem_length)
|
||||
i = 0
|
||||
while i < max_elem_length+1:
|
||||
elem_onehots = self._char_to_onehot(
|
||||
|
@ -124,6 +121,7 @@ class TableAttentionHead(nn.Layer):
|
|||
loc_preds = F.sigmoid(loc_preds)
|
||||
return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
|
||||
|
||||
|
||||
class AttentionGRUCell(nn.Layer):
|
||||
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
|
||||
super(AttentionGRUCell, self).__init__()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -31,70 +31,61 @@ class TableFPN(nn.Layer):
|
|||
in_channels=in_channels[0],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_51.w_0', initializer=weight_attr),
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.in3_conv = nn.Conv2D(
|
||||
in_channels=in_channels[1],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
stride = 1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_50.w_0', initializer=weight_attr),
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.in4_conv = nn.Conv2D(
|
||||
in_channels=in_channels[2],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_49.w_0', initializer=weight_attr),
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.in5_conv = nn.Conv2D(
|
||||
in_channels=in_channels[3],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_48.w_0', initializer=weight_attr),
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p5_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_52.w_0', initializer=weight_attr),
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p4_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_53.w_0', initializer=weight_attr),
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p3_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_54.w_0', initializer=weight_attr),
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p2_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_55.w_0', initializer=weight_attr),
|
||||
weight_attr=ParamAttr(initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.fuse_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels * 4,
|
||||
out_channels=512,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_fuse.w_0', initializer=weight_attr), bias_attr=False)
|
||||
weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False)
|
||||
|
||||
def forward(self, x):
|
||||
c2, c3, c4, c5 = x
|
||||
|
|
|
@ -369,18 +369,6 @@ class TableLabelDecode(object):
|
|||
list_character = [self.beg_str] + list_character + [self.end_str]
|
||||
return list_character
|
||||
|
||||
def get_sp_tokens(self):
|
||||
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
|
||||
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
|
||||
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
|
||||
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
|
||||
elem_char_idx1 = self.dict_elem['<td>']
|
||||
elem_char_idx2 = self.dict_elem['<td']
|
||||
sp_tokens = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
|
||||
elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
|
||||
self.max_elem_length, self.max_cell_num])
|
||||
return sp_tokens
|
||||
|
||||
def __call__(self, preds):
|
||||
structure_probs = preds['structure_probs']
|
||||
loc_preds = preds['loc_preds']
|
||||
|
|
|
@ -60,7 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
|
|||
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
|
||||
)
|
||||
infer_shape[-1] = 100
|
||||
|
||||
elif arch_config["model_type"] == "table":
|
||||
infer_shape = [3, 488, 488]
|
||||
model = to_static(
|
||||
model,
|
||||
input_spec=[
|
||||
|
|
|
@ -79,11 +79,9 @@ def main(config, device, logger, vdl_writer):
|
|||
img = f.read()
|
||||
data = {'image': img}
|
||||
batch = transform(data, ops)
|
||||
sp_tokens = post_process_class.get_sp_tokens()
|
||||
targets = [[], [], paddle.to_tensor([sp_tokens])]
|
||||
images = np.expand_dims(batch[0], axis=0)
|
||||
images = paddle.to_tensor(images)
|
||||
preds = model(images, data=targets, mode='Test')
|
||||
preds = model(images, data=None, mode='Test')
|
||||
post_result = post_process_class(preds)
|
||||
res_html_code = post_result['res_html_code']
|
||||
res_loc = post_result['res_loc']
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -276,6 +276,7 @@ def train(config,
|
|||
valid_dataloader,
|
||||
post_process_class,
|
||||
eval_class,
|
||||
"table",
|
||||
use_srn=use_srn)
|
||||
cur_metric_str = 'cur metric, {}'.format(', '.join(
|
||||
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
|
||||
|
|
Loading…
Reference in New Issue