refine

parent 7c8b2c8d19
commit 16c247ac46
@@ -1,13 +1,12 @@
 Global:
   use_gpu: true
-  epoch_num: 40
+  epoch_num: 50
   log_smooth_window: 20
   print_batch_step: 5
   save_model_dir: ./output/table_mv3/
-  save_epoch_step: 3
+  save_epoch_step: 5
-  # evaluation is run every 5000 iterations after the 4000th iteration
+  # evaluation is run every 400 iterations after the 0th iteration
   eval_batch_step: [0, 400]
-  # if pretrained_model is saved in static mode, load_static_weights must set to True
   cal_metric_during_train: True
   pretrained_model:
   checkpoints:
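Note on the corrected comment above: eval_batch_step takes a [start, interval] pair, so [0, 400] means evaluation begins after iteration 0 and then runs every 400 iterations. A minimal sketch of that gating logic, assuming the conventional interpretation of this config key (illustrative, not the trainer's actual code):

start_eval_step, eval_interval = 0, 400  # from eval_batch_step: [0, 400]

def should_evaluate(global_step):
    # Evaluate only past the warm-up threshold, on every interval boundary.
    return global_step > start_eval_step and global_step % eval_interval == 0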
@@ -18,19 +17,20 @@ Global:
   character_dict_path: ppocr/utils/dict/table_structure_dict.txt
   character_type: en
   max_text_length: 100
-  max_elem_length: 800
+  max_elem_length: 500
   max_cell_num: 500
   infer_mode: False
   process_total_num: 0
   process_cut_num: 0

+
 Optimizer:
   name: Adam
   beta1: 0.9
   beta2: 0.999
   clip_norm: 5.0
   lr:
-    learning_rate: 0.0001
+    learning_rate: 0.001
   regularizer:
     name: 'L2'
     factor: 0.00000
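For reference, a hypothetical Paddle equivalent of the Optimizer block after this change, assuming the usual mapping from PaddleOCR config keys to paddle APIs. The Linear layer is a stand-in model so the snippet is self-contained:

import paddle

model = paddle.nn.Linear(4, 4)  # stand-in model, for illustration only
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=5.0)
optimizer = paddle.optimizer.Adam(
    learning_rate=0.001,  # raised from 0.0001 in this commit
    beta1=0.9,
    beta2=0.999,
    grad_clip=clip,
    weight_decay=0.0,  # regularizer factor 0.00000
    parameters=model.parameters())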
@@ -41,12 +41,12 @@ Architecture:
   Backbone:
     name: MobileNetV3
     scale: 1.0
-    model_name: large
+    model_name: small
+    disable_se: True
   Head:
-    name: TableAttentionHead # AttentionHead
-    hidden_size: 256 #
+    name: TableAttentionHead
+    hidden_size: 256
     l2_decay: 0.00001
-    # loc_type: 1
     loc_type: 2

 Loss:
@@ -86,7 +86,7 @@ Train:
     shuffle: True
     batch_size_per_card: 32
     drop_last: True
-    num_workers: 4
+    num_workers: 1

 Eval:
   dataset:
@@ -113,4 +113,4 @@ Eval:
     shuffle: False
     drop_last: False
     batch_size_per_card: 16
-    num_workers: 4
+    num_workers: 1
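The Architecture change above switches to the small MobileNetV3 variant and disables its squeeze-and-excitation blocks. A hedged sketch of the resulting constructor call, assuming the detector-style MobileNetV3 backbone that exposes a disable_se flag (the import path is an assumption):

from ppocr.modeling.backbones.det_mobilenet_v3 import MobileNetV3

# scale, model_name and disable_se mirror the Backbone keys above.
backbone = MobileNetV3(
    in_channels=3, scale=1.0, model_name="small", disable_se=True)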
@@ -412,7 +412,6 @@ class TableLabelEncode(object):
             return None
         elem_num = len(structure)
         structure = [0] + structure + [len(self.dict_elem) - 1]
-        # structure = [0] + structure + [0]
         structure = structure + [0] * (self.max_elem_length + 2 - len(structure))
         structure = np.array(structure)
         data['structure'] = structure
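The padding scheme kept in this hunk is easiest to see with toy values. An illustrative, self-contained rerun of the same three lines (the vocabulary size and token ids are made up; token 0 serves as both start token and padding, len(dict_elem) - 1 as end token):

import numpy as np

max_elem_length = 500          # matches the config change above
dict_elem_size = 30            # assumed vocabulary size, for illustration
structure = [5, 7, 7, 6]       # toy structure-token ids
structure = [0] + structure + [dict_elem_size - 1]
structure = structure + [0] * (max_elem_length + 2 - len(structure))
structure = np.array(structure)
print(structure.shape)         # (502,): max_elem_length + start + end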
@@ -443,8 +442,6 @@ class TableLabelEncode(object):
                     if cand_span_idx < (self.max_elem_length + 2):
                         if structure[cand_span_idx] in span_idx_list:
                             structure_mask[cand_span_idx] = span_weight
-                            # structure_mask[td_idx] = self.span_weight
-                            # structure_mask[cand_span_idx] = self.span_weight

         data['bbox_list'] = bbox_list
         data['bbox_list_mask'] = bbox_list_mask
@@ -458,23 +455,6 @@ class TableLabelEncode(object):
             self.max_elem_length, self.max_cell_num, elem_num])
         return data

-        ########
-        # for char decode
-        # cell_list = []
-        # for cell in cells:
-        #     char_list = cell['tokens']
-        #     cell = self.encode(char_list, 'char')
-        #     if cell is None:
-        #         return None
-        #     cell = [0] + cell + [len(self.dict_character) - 1]
-        #     cell = cell + [0] * (self.max_text_length + 2 - len(cell))
-        #     cell_list.append(cell)
-        # cell_list_padding = np.zeros((self.max_cell_num, self.max_text_length + 2))
-        # cell_list = np.array(cell_list)
-        # cell_list_padding[0:cell_list.shape[0]] = cell_list
-        # data['cells'] = cell_list_padding
-        # return data
-
     def encode(self, text, char_or_elem):
         """convert text-label into text-index.
         """
@@ -1,4 +1,4 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ import json

 from .imaug import transform, create_operators

+
 class PubTabDataSet(Dataset):
     def __init__(self, config, mode, logger, seed=None):
         super(PubTabDataSet, self).__init__()
@@ -58,23 +59,6 @@ class PubTabDataSet(Dataset):
             random.shuffle(self.data_lines)
         return

-    def load_hard_select_prob(self):
-        label_path = "./pretrained_model/teds_score_exp5_st2_train.txt"
-        img_select_prob = {}
-        with open(label_path, "rb") as fin:
-            lines = fin.readlines()
-            for lno in range(len(lines)):
-                substr = lines[lno].decode('utf-8').strip("\n").split(" ")
-                img_name = substr[0].strip(":")
-                score = float(substr[1])
-                if score <= 0.8:
-                    img_select_prob[img_name] = self.hard_prob[0]
-                elif score <= 0.98:
-                    img_select_prob[img_name] = self.hard_prob[1]
-                else:
-                    img_select_prob[img_name] = self.hard_prob[2]
-        return img_select_prob
-
     def __getitem__(self, idx):
         try:
             data_line = self.data_lines[idx]
@@ -93,8 +77,6 @@ class PubTabDataSet(Dataset):
             table_type = "simple"
             if 'colspan' in structure_str or 'rowspan' in structure_str:
                 table_type = "complex"
-            # if self.table_select_type != table_type:
-            #     select_flag = False
            if table_type == "complex":
                if self.table_select_prob < random.uniform(0, 1):
                    select_flag = False
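The sampling rule retained at the end of this hunk downweights complex tables: a sample whose HTML contains colspan or rowspan is kept only with probability table_select_prob. A minimal sketch of that rule in isolation (the probability value is illustrative):

import random

table_select_prob = 0.5  # assumed value, for illustration

def keep_sample(structure_str):
    # Complex tables (spans present) are kept with probability table_select_prob.
    if 'colspan' in structure_str or 'rowspan' in structure_str:
        return table_select_prob >= random.uniform(0, 1)
    return True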
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,13 +21,16 @@ import paddle.nn as nn
 import paddle.nn.functional as F
 import numpy as np


 class TableAttentionHead(nn.Layer):
     def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
         super(TableAttentionHead, self).__init__()
         self.input_size = in_channels[-1]
         self.hidden_size = hidden_size
-        self.char_num = 280
         self.elem_num = 30
+        self.max_text_length = 100
+        self.max_elem_length = 500
+        self.max_cell_num = 500
+
         self.structure_attention_cell = AttentionGRUCell(
             self.input_size, hidden_size, self.elem_num, use_gru=False)
@@ -39,11 +42,11 @@ class TableAttentionHead(nn.Layer):
             self.loc_generator = nn.Linear(hidden_size, 4)
         else:
             if self.in_max_len == 640:
-                self.loc_fea_trans = nn.Linear(400, 801)
+                self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
             elif self.in_max_len == 800:
-                self.loc_fea_trans = nn.Linear(625, 801)
+                self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
             else:
-                self.loc_fea_trans = nn.Linear(256, 801)
+                self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
             self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)

     def _char_to_onehot(self, input_char, onehot_dim):
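The three Linear replacements above make the layer width track the configured sequence length: the location feature is projected to one slot per decoded structure step, so out_features must be max_elem_length + 1 rather than the hardcoded 801 (which was correct only while max_elem_length was 800). A sketch of the relationship, with shapes illustrative:

import paddle.nn as nn

max_elem_length = 500  # now owned by the head instead of passed via targets
loc_fea_trans = nn.Linear(400, max_elem_length + 1)  # 501 output slots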
@@ -61,18 +64,12 @@ class TableAttentionHead(nn.Layer):
         fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
         fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
         batch_size = fea.shape[0]
-        #sp_tokens = targets[2].numpy()
-        #char_beg_idx, char_end_idx = sp_tokens[0, 0:2]
-        #elem_beg_idx, elem_end_idx = sp_tokens[0, 2:4]
-        #elem_char_idx1, elem_char_idx2 = sp_tokens[0, 4:6]
-        #max_text_length, max_elem_length, max_cell_num = sp_tokens[0, 6:9]
-        max_text_length, max_elem_length, max_cell_num = 100, 800, 500

         hidden = paddle.zeros((batch_size, self.hidden_size))
         output_hiddens = []
         if mode == 'Train' and targets is not None:
             structure = targets[0]
-            for i in range(max_elem_length+1):
+            for i in range(self.max_elem_length+1):
                 elem_onehots = self._char_to_onehot(
                     structure[:, i], onehot_dim=self.elem_num)
                 (outputs, hidden), alpha = self.structure_attention_cell(
@@ -97,7 +94,7 @@ class TableAttentionHead(nn.Layer):
             elem_onehots = None
             outputs = None
             alpha = None
-            max_elem_length = paddle.to_tensor(max_elem_length)
+            max_elem_length = paddle.to_tensor(self.max_elem_length)
             i = 0
             while i < max_elem_length+1:
                 elem_onehots = self._char_to_onehot(
@@ -124,6 +121,7 @@ class TableAttentionHead(nn.Layer):
         loc_preds = F.sigmoid(loc_preds)
         return {'structure_probs':structure_probs, 'loc_preds':loc_preds}


 class AttentionGRUCell(nn.Layer):
     def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
         super(AttentionGRUCell, self).__init__()
@@ -1,4 +1,4 @@
-# copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,70 +31,61 @@ class TableFPN(nn.Layer):
             in_channels=in_channels[0],
             out_channels=self.out_channels,
             kernel_size=1,
-            weight_attr=ParamAttr(
-                name='conv2d_51.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.in3_conv = nn.Conv2D(
             in_channels=in_channels[1],
             out_channels=self.out_channels,
             kernel_size=1,
             stride = 1,
-            weight_attr=ParamAttr(
-                name='conv2d_50.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.in4_conv = nn.Conv2D(
             in_channels=in_channels[2],
             out_channels=self.out_channels,
             kernel_size=1,
-            weight_attr=ParamAttr(
-                name='conv2d_49.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.in5_conv = nn.Conv2D(
             in_channels=in_channels[3],
             out_channels=self.out_channels,
             kernel_size=1,
-            weight_attr=ParamAttr(
-                name='conv2d_48.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p5_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_52.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p4_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_53.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p3_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_54.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.p2_conv = nn.Conv2D(
             in_channels=self.out_channels,
             out_channels=self.out_channels // 4,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_55.w_0', initializer=weight_attr),
+            weight_attr=ParamAttr(initializer=weight_attr),
             bias_attr=False)
         self.fuse_conv = nn.Conv2D(
             in_channels=self.out_channels * 4,
             out_channels=512,
             kernel_size=3,
             padding=1,
-            weight_attr=ParamAttr(
-                name='conv2d_fuse.w_0', initializer=weight_attr), bias_attr=False)
+            weight_attr=ParamAttr(initializer=weight_attr), bias_attr=False)

     def forward(self, x):
         c2, c3, c4, c5 = x
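Every Conv2D in this hunk drops its hand-assigned parameter name (conv2d_48.w_0 through conv2d_55.w_0 and conv2d_fuse.w_0) and keeps only the initializer, letting Paddle generate unique parameter names automatically. A sketch of the pattern on a single layer, with channel counts and initializer chosen for illustration:

import paddle.nn as nn
from paddle import ParamAttr

weight_attr = nn.initializer.KaimingUniform()  # assumed initializer
conv = nn.Conv2D(
    in_channels=96,
    out_channels=96,
    kernel_size=1,
    weight_attr=ParamAttr(initializer=weight_attr),  # no explicit name
    bias_attr=False)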
@@ -369,18 +369,6 @@ class TableLabelDecode(object):
         list_character = [self.beg_str] + list_character + [self.end_str]
         return list_character

-    def get_sp_tokens(self):
-        char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
-        char_end_idx = self.get_beg_end_flag_idx('end', 'char')
-        elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
-        elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
-        elem_char_idx1 = self.dict_elem['<td>']
-        elem_char_idx2 = self.dict_elem['<td']
-        sp_tokens = np.array([char_beg_idx, char_end_idx, elem_beg_idx,
-            elem_end_idx, elem_char_idx1, elem_char_idx2, self.max_text_length,
-            self.max_elem_length, self.max_cell_num])
-        return sp_tokens
-
     def __call__(self, preds):
         structure_probs = preds['structure_probs']
         loc_preds = preds['loc_preds']
@@ -60,7 +60,8 @@ def export_single_model(model, arch_config, save_path, logger):
                 "When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
             )
             infer_shape[-1] = 100
+    elif arch_config["model_type"] == "table":
+        infer_shape = [3, 488, 488]
     model = to_static(
         model,
         input_spec=[
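The new branch gives table models a fixed export shape matching their 488x488 training input. A hedged sketch of the input spec that the surrounding to_static call presumably receives for this branch, with the batch dimension left dynamic:

from paddle.static import InputSpec

infer_shape = [3, 488, 488]  # C, H, W for the table model
input_spec = [InputSpec(shape=[None] + infer_shape, dtype="float32")]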
@@ -79,11 +79,9 @@ def main(config, device, logger, vdl_writer):
         img = f.read()
     data = {'image': img}
     batch = transform(data, ops)
-    sp_tokens = post_process_class.get_sp_tokens()
-    targets = [[], [], paddle.to_tensor([sp_tokens])]
     images = np.expand_dims(batch[0], axis=0)
     images = paddle.to_tensor(images)
-    preds = model(images, data=targets, mode='Test')
+    preds = model(images, data=None, mode='Test')
     post_result = post_process_class(preds)
     res_html_code = post_result['res_html_code']
     res_loc = post_result['res_loc']
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -276,6 +276,7 @@ def train(config,
                 valid_dataloader,
                 post_process_class,
                 eval_class,
+                "table",
                 use_srn=use_srn)
             cur_metric_str = 'cur metric, {}'.format(', '.join(
                 ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
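The added "table" argument threads the model type into evaluation so table models take the table metric path. An illustrative stand-in showing only the updated call shape (the real evaluation function lives in the surrounding training module and is not redefined here):

def eval_stub(model, loader, post_process, metric, model_type, use_srn=False):
    # Stand-in for the project's eval routine; shows the new positional arg.
    return {"model_type": model_type}

cur_metric = eval_stub(None, None, None, None, "table", use_srn=False)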