polish seed code
This commit is contained in:
parent
1effa5f3fe
commit
a14f8da961
|
@ -19,7 +19,6 @@ Global:
|
|||
max_text_length: 100
|
||||
infer_mode: False
|
||||
use_space_char: False
|
||||
eval_filter: True
|
||||
save_res_path: ./output/rec/predicts_seed.txt
|
||||
|
||||
|
||||
|
@ -37,8 +36,8 @@ Optimizer:
|
|||
|
||||
|
||||
Architecture:
|
||||
model_type: seed
|
||||
algorithm: ASTER
|
||||
model_type: rec
|
||||
algorithm: seed
|
||||
Transform:
|
||||
name: STN_ON
|
||||
tps_inputsize: [32, 64]
|
||||
|
@ -76,8 +75,10 @@ Train:
|
|||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SEEDLabelEncode: # Class handling label
|
||||
- SEEDResize:
|
||||
- RecResizeImg:
|
||||
character_type: en
|
||||
image_shape: [3, 64, 256]
|
||||
padding: False
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length', 'fast_label'] # dataloader will return list in this order
|
||||
loader:
|
||||
|
@ -95,8 +96,10 @@ Eval:
|
|||
img_mode: BGR
|
||||
channel_first: False
|
||||
- SEEDLabelEncode: # Class handling label
|
||||
- SEEDResize:
|
||||
- RecResizeImg:
|
||||
character_type: en
|
||||
image_shape: [3, 64, 256]
|
||||
padding: False
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
|
|
|
@ -106,7 +106,6 @@ class BaseRecLabelEncode(object):
|
|||
self.max_text_len = max_text_length
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
self.unknown = "UNKNOWN"
|
||||
if character_type == "en":
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
|
@ -357,7 +356,6 @@ class SEEDLabelEncode(BaseRecLabelEncode):
|
|||
character_type, use_space_char)
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
dict_character = dict_character + [self.end_str]
|
||||
return dict_character
|
||||
|
|
|
@ -88,29 +88,19 @@ class RecResizeImg(object):
|
|||
image_shape,
|
||||
infer_mode=False,
|
||||
character_type='ch',
|
||||
padding=True,
|
||||
**kwargs):
|
||||
self.image_shape = image_shape
|
||||
self.infer_mode = infer_mode
|
||||
self.character_type = character_type
|
||||
self.padding = padding
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if self.infer_mode and self.character_type == "ch":
|
||||
norm_img = resize_norm_img_chinese(img, self.image_shape)
|
||||
else:
|
||||
norm_img = resize_norm_img(img, self.image_shape)
|
||||
data['image'] = norm_img
|
||||
return data
|
||||
|
||||
|
||||
class SEEDResize(object):
|
||||
def __init__(self, image_shape, infer_mode=False, **kwargs):
|
||||
self.image_shape = image_shape
|
||||
self.infer_mode = infer_mode
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
norm_img = resize_no_padding_img(img, self.image_shape)
|
||||
norm_img = resize_norm_img(img, self.image_shape, self.padding)
|
||||
data['image'] = norm_img
|
||||
return data
|
||||
|
||||
|
@ -186,16 +176,21 @@ def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
|
|||
return padding_im, resize_shape, pad_shape, valid_ratio
|
||||
|
||||
|
||||
def resize_norm_img(img, image_shape):
|
||||
def resize_norm_img(img, image_shape, padding=True):
|
||||
imgC, imgH, imgW = image_shape
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
if not padding:
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
resized_image = cv2.resize(img, (resized_w, imgH))
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
resized_image = cv2.resize(img, (resized_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
if image_shape[0] == 1:
|
||||
resized_image = resized_image / 255
|
||||
|
@ -209,17 +204,6 @@ def resize_norm_img(img, image_shape):
|
|||
return padding_im
|
||||
|
||||
|
||||
def resize_no_padding_img(img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
resized_image = cv2.resize(
|
||||
img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
return resized_image
|
||||
|
||||
|
||||
def resize_norm_img_chinese(img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
# todo: change to 0 and modified image shape
|
||||
|
|
|
@ -17,7 +17,7 @@ __all__ = ['build_transform']
|
|||
|
||||
def build_transform(config):
|
||||
from .tps import TPS
|
||||
from .tps import STN_ON
|
||||
from .stn import STN_ON
|
||||
|
||||
support_dict = ['TPS', 'STN_ON']
|
||||
|
||||
|
|
|
@ -22,6 +22,8 @@ from paddle import nn, ParamAttr
|
|||
from paddle.nn import functional as F
|
||||
import numpy as np
|
||||
|
||||
from .tps_spatial_transformer import TPSSpatialTransformer
|
||||
|
||||
|
||||
def conv3x3_block(in_channels, out_channels, stride=1):
|
||||
n = 3 * 3 * out_channels
|
||||
|
@ -106,3 +108,25 @@ class STN(nn.Layer):
|
|||
x = F.sigmoid(x)
|
||||
x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2])
|
||||
return img_feat, x
|
||||
|
||||
|
||||
class STN_ON(nn.Layer):
|
||||
def __init__(self, in_channels, tps_inputsize, tps_outputsize,
|
||||
num_control_points, tps_margins, stn_activation):
|
||||
super(STN_ON, self).__init__()
|
||||
self.tps = TPSSpatialTransformer(
|
||||
output_image_size=tuple(tps_outputsize),
|
||||
num_control_points=num_control_points,
|
||||
margins=tuple(tps_margins))
|
||||
self.stn_head = STN(in_channels=in_channels,
|
||||
num_ctrlpoints=num_control_points,
|
||||
activation=stn_activation)
|
||||
self.tps_inputsize = tps_inputsize
|
||||
self.out_channels = in_channels
|
||||
|
||||
def forward(self, image):
|
||||
stn_input = paddle.nn.functional.interpolate(
|
||||
image, self.tps_inputsize, mode="bilinear", align_corners=True)
|
||||
stn_img_feat, ctrl_points = self.stn_head(stn_input)
|
||||
x, _ = self.tps(image, ctrl_points)
|
||||
return x
|
||||
|
|
|
@ -22,9 +22,6 @@ from paddle import nn, ParamAttr
|
|||
from paddle.nn import functional as F
|
||||
import numpy as np
|
||||
|
||||
from .tps_spatial_transformer import TPSSpatialTransformer
|
||||
from .stn import STN
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
|
@ -305,25 +302,3 @@ class TPS(nn.Layer):
|
|||
[-1, image.shape[2], image.shape[3], 2])
|
||||
batch_I_r = F.grid_sample(x=image, grid=batch_P_prime)
|
||||
return batch_I_r
|
||||
|
||||
|
||||
class STN_ON(nn.Layer):
|
||||
def __init__(self, in_channels, tps_inputsize, tps_outputsize,
|
||||
num_control_points, tps_margins, stn_activation):
|
||||
super(STN_ON, self).__init__()
|
||||
self.tps = TPSSpatialTransformer(
|
||||
output_image_size=tuple(tps_outputsize),
|
||||
num_control_points=num_control_points,
|
||||
margins=tuple(tps_margins))
|
||||
self.stn_head = STN(in_channels=in_channels,
|
||||
num_ctrlpoints=num_control_points,
|
||||
activation=stn_activation)
|
||||
self.tps_inputsize = tps_inputsize
|
||||
self.out_channels = in_channels
|
||||
|
||||
def forward(self, image):
|
||||
stn_input = paddle.nn.functional.interpolate(
|
||||
image, self.tps_inputsize, mode="bilinear", align_corners=True)
|
||||
stn_img_feat, ctrl_points = self.stn_head(stn_input)
|
||||
x, _ = self.tps(image, ctrl_points)
|
||||
return x
|
||||
|
|
|
@ -322,7 +322,6 @@ class SEEDLabelDecode(BaseRecLabelDecode):
|
|||
def add_special_char(self, dict_character):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
dict_character = dict_character
|
||||
dict_character = dict_character + [self.end_str]
|
||||
return dict_character
|
||||
|
||||
|
|
|
@ -11,4 +11,5 @@ opencv-contrib-python==4.4.0.46
|
|||
cython
|
||||
lxml
|
||||
premailer
|
||||
openpyxl
|
||||
openpyxl
|
||||
fasttext==0.9.1
|
|
@ -186,9 +186,8 @@ def train(config,
|
|||
model.train()
|
||||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
use_nrtr = config['Architecture']['algorithm'] == "NRTR"
|
||||
use_sar = config['Architecture']['algorithm'] == 'SAR'
|
||||
use_seed = config['Architecture']['algorithm'] == 'SEED'
|
||||
extra_input = config['Architecture'][
|
||||
'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
|
||||
try:
|
||||
model_type = config['Architecture']['model_type']
|
||||
except:
|
||||
|
@ -217,7 +216,7 @@ def train(config,
|
|||
images = batch[0]
|
||||
if use_srn:
|
||||
model_average = True
|
||||
if use_srn or model_type == 'table' or use_nrtr or use_sar or use_seed:
|
||||
if model_type == 'table' or extra_input:
|
||||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
|
@ -281,8 +280,7 @@ def train(config,
|
|||
post_process_class,
|
||||
eval_class,
|
||||
model_type,
|
||||
use_srn=use_srn,
|
||||
use_sar=use_sar)
|
||||
extra_input=extra_input)
|
||||
cur_metric_str = 'cur metric, {}'.format(', '.join(
|
||||
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
|
||||
logger.info(cur_metric_str)
|
||||
|
@ -354,8 +352,7 @@ def eval(model,
|
|||
post_process_class,
|
||||
eval_class,
|
||||
model_type=None,
|
||||
use_srn=False,
|
||||
use_sar=False):
|
||||
extra_input=False):
|
||||
model.eval()
|
||||
with paddle.no_grad():
|
||||
total_frame = 0.0
|
||||
|
@ -368,7 +365,7 @@ def eval(model,
|
|||
break
|
||||
images = batch[0]
|
||||
start = time.time()
|
||||
if use_srn or model_type == 'table' or use_sar:
|
||||
if model_type == 'table' or extra_input:
|
||||
preds = model(images, data=batch[1:])
|
||||
else:
|
||||
preds = model(images)
|
||||
|
|
Loading…
Reference in New Issue