trans to paddle-rc
This commit is contained in:
parent
fa675f8954
commit
1ae379198e
|
@ -37,6 +37,7 @@ from ppocr.data.lmdb_dataset import LMDBDateSet
|
|||
|
||||
__all__ = ['build_dataloader', 'transform', 'create_operators']
|
||||
|
||||
|
||||
def term_mp(sig_num, frame):
|
||||
""" kill all child processes
|
||||
"""
|
||||
|
@ -45,24 +46,27 @@ def term_mp(sig_num, frame):
|
|||
print("main proc {} exit, kill process group " "{}".format(pid, pgid))
|
||||
os.killpg(pgid, signal.SIGKILL)
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, term_mp)
|
||||
signal.signal(signal.SIGTERM, term_mp)
|
||||
|
||||
def build_dataloader(config, mode, device):
|
||||
|
||||
def build_dataloader(config, mode, device, logger):
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
|
||||
support_dict = ['SimpleDataSet', 'LMDBDateSet']
|
||||
module_name = config[mode]['dataset']['name']
|
||||
assert module_name in support_dict, Exception(
|
||||
'DataSet only support {}'.format(support_dict))
|
||||
assert mode in ['Train', 'Eval', 'Test'], "Mode should be Train, Eval or Test."
|
||||
|
||||
dataset = eval(module_name)(config, mode)
|
||||
assert mode in ['Train', 'Eval', 'Test'
|
||||
], "Mode should be Train, Eval or Test."
|
||||
|
||||
dataset = eval(module_name)(config, mode, logger)
|
||||
loader_config = config[mode]['loader']
|
||||
batch_size = loader_config['batch_size_per_card']
|
||||
drop_last = loader_config['drop_last']
|
||||
num_workers = loader_config['num_workers']
|
||||
|
||||
|
||||
if mode == "Train":
|
||||
#Distribute data to multiple cards
|
||||
batch_sampler = DistributedBatchSampler(
|
||||
|
@ -76,14 +80,13 @@ def build_dataloader(config, mode, device):
|
|||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
drop_last=drop_last)
|
||||
|
||||
drop_last=drop_last)
|
||||
|
||||
data_loader = DataLoader(
|
||||
dataset=dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
places=device,
|
||||
num_workers=num_workers,
|
||||
return_list=True)
|
||||
|
||||
|
||||
return data_loader
|
||||
#return data_loader, _dataset.info_dict
|
|
@ -22,37 +22,26 @@ import lmdb
|
|||
import cv2
|
||||
|
||||
from .imaug import transform, create_operators
|
||||
from ppocr.utils.logging import get_logger
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class LMDBDateSet(Dataset):
|
||||
def __init__(self, config, mode):
|
||||
def __init__(self, config, mode, logger):
|
||||
super(LMDBDateSet, self).__init__()
|
||||
|
||||
|
||||
global_config = config['Global']
|
||||
dataset_config = config[mode]['dataset']
|
||||
loader_config = config[mode]['loader']
|
||||
batch_size = loader_config['batch_size_per_card']
|
||||
data_dir = dataset_config['data_dir']
|
||||
self.do_shuffle = loader_config['shuffle']
|
||||
|
||||
|
||||
self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
|
||||
|
||||
logger.info("Initialize indexs of datasets:%s" % data_dir)
|
||||
self.data_idx_order_list = self.dataset_traversal()
|
||||
if self.do_shuffle:
|
||||
np.random.shuffle(self.data_idx_order_list)
|
||||
self.ops = create_operators(dataset_config['transforms'], global_config)
|
||||
|
||||
|
||||
# # for rec
|
||||
# character = ''
|
||||
# for op in self.ops:
|
||||
# if hasattr(op, 'character'):
|
||||
# character = getattr(op, 'character')
|
||||
|
||||
# self.info_dict = {'character': character}
|
||||
|
||||
def load_hierarchical_lmdb_dataset(self, data_dir):
|
||||
lmdb_sets = {}
|
||||
dataset_idx = 0
|
||||
|
@ -71,7 +60,7 @@ class LMDBDateSet(Dataset):
|
|||
"txn":txn, "num_samples":num_samples}
|
||||
dataset_idx += 1
|
||||
return lmdb_sets
|
||||
|
||||
|
||||
def dataset_traversal(self):
|
||||
lmdb_num = len(self.lmdb_sets)
|
||||
total_sample_num = 0
|
||||
|
@ -88,7 +77,7 @@ class LMDBDateSet(Dataset):
|
|||
data_idx_order_list[beg_idx:end_idx, 1] += 1
|
||||
beg_idx = beg_idx + tmp_sample_num
|
||||
return data_idx_order_list
|
||||
|
||||
|
||||
def get_img_data(self, value):
|
||||
"""get_img_data"""
|
||||
if not value:
|
||||
|
@ -110,15 +99,15 @@ class LMDBDateSet(Dataset):
|
|||
img_key = 'image-%09d'.encode() % index
|
||||
imgbuf = txn.get(img_key)
|
||||
return imgbuf, label
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
lmdb_idx, file_idx = self.data_idx_order_list[idx]
|
||||
lmdb_idx = int(lmdb_idx)
|
||||
file_idx = int(file_idx)
|
||||
sample_info = self.get_lmdb_sample_info(
|
||||
self.lmdb_sets[lmdb_idx]['txn'], file_idx)
|
||||
sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
|
||||
file_idx)
|
||||
if sample_info is None:
|
||||
return self.__getitem__(np.random.randint(self.__len__()))
|
||||
return self.__getitem__(np.random.randint(self.__len__()))
|
||||
img, label = sample_info
|
||||
data = {'image': img, 'label': label}
|
||||
outs = transform(data, self.ops)
|
||||
|
@ -128,4 +117,3 @@ class LMDBDateSet(Dataset):
|
|||
|
||||
def __len__(self):
|
||||
return self.data_idx_order_list.shape[0]
|
||||
|
||||
|
|
|
@ -20,18 +20,17 @@ from paddle.io import Dataset
|
|||
import time
|
||||
|
||||
from .imaug import transform, create_operators
|
||||
from ppocr.utils.logging import get_logger
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class SimpleDataSet(Dataset):
|
||||
def __init__(self, config, mode):
|
||||
def __init__(self, config, mode, logger):
|
||||
super(SimpleDataSet, self).__init__()
|
||||
|
||||
|
||||
global_config = config['Global']
|
||||
dataset_config = config[mode]['dataset']
|
||||
loader_config = config[mode]['loader']
|
||||
batch_size = loader_config['batch_size_per_card']
|
||||
|
||||
|
||||
self.delimiter = dataset_config.get('delimiter', '\t')
|
||||
label_file_list = dataset_config.pop('label_file_list')
|
||||
data_source_num = len(label_file_list)
|
||||
|
@ -39,19 +38,21 @@ class SimpleDataSet(Dataset):
|
|||
ratio_list = [1.0]
|
||||
else:
|
||||
ratio_list = dataset_config.pop('ratio_list')
|
||||
|
||||
|
||||
assert sum(ratio_list) == 1, "The sum of the ratio_list should be 1."
|
||||
assert len(ratio_list) == data_source_num, "The length of ratio_list should be the same as the file_list."
|
||||
assert len(
|
||||
ratio_list
|
||||
) == data_source_num, "The length of ratio_list should be the same as the file_list."
|
||||
self.data_dir = dataset_config['data_dir']
|
||||
self.do_shuffle = loader_config['shuffle']
|
||||
|
||||
|
||||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||
self.data_lines_list, data_num_list = self.get_image_info_list(
|
||||
label_file_list)
|
||||
self.data_idx_order_list = self.dataset_traversal(
|
||||
data_num_list, ratio_list, batch_size)
|
||||
self.shuffle_data_random()
|
||||
|
||||
|
||||
self.ops = create_operators(dataset_config['transforms'], global_config)
|
||||
|
||||
def get_image_info_list(self, file_list):
|
||||
|
@ -65,7 +66,7 @@ class SimpleDataSet(Dataset):
|
|||
data_lines_list.append(lines)
|
||||
data_num_list.append(len(lines))
|
||||
return data_lines_list, data_num_list
|
||||
|
||||
|
||||
def dataset_traversal(self, data_num_list, ratio_list, batch_size):
|
||||
select_num_list = []
|
||||
dataset_num = len(data_num_list)
|
||||
|
@ -87,8 +88,7 @@ class SimpleDataSet(Dataset):
|
|||
cur_index = cur_index_sets[dataset_idx]
|
||||
if cur_index >= data_num_list[dataset_idx]:
|
||||
break
|
||||
data_idx_order_list.append((
|
||||
dataset_idx, cur_index))
|
||||
data_idx_order_list.append((dataset_idx, cur_index))
|
||||
cur_index_sets[dataset_idx] += 1
|
||||
if finish_read_num == dataset_num:
|
||||
break
|
||||
|
@ -99,7 +99,7 @@ class SimpleDataSet(Dataset):
|
|||
for dno in range(len(self.data_lines_list)):
|
||||
random.shuffle(self.data_lines_list[dno])
|
||||
return
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
dataset_idx, file_idx = self.data_idx_order_list[idx]
|
||||
data_line = self.data_lines_list[dataset_idx][file_idx]
|
||||
|
@ -119,4 +119,3 @@ class SimpleDataSet(Dataset):
|
|||
|
||||
def __len__(self):
|
||||
return len(self.data_idx_order_list)
|
||||
|
||||
|
|
|
@ -158,7 +158,7 @@ class ConvBNLayer(nn.Layer):
|
|||
super(ConvBNLayer, self).__init__()
|
||||
self.if_act = if_act
|
||||
self.act = act
|
||||
self.conv = nn.Conv2d(
|
||||
self.conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
|
@ -183,7 +183,7 @@ class ConvBNLayer(nn.Layer):
|
|||
if self.act == "relu":
|
||||
x = F.relu(x)
|
||||
elif self.act == "hard_swish":
|
||||
x = F.hard_swish(x)
|
||||
x = F.activation.hard_swish(x)
|
||||
else:
|
||||
print("The activation function is selected incorrectly.")
|
||||
exit()
|
||||
|
@ -242,16 +242,15 @@ class ResidualUnit(nn.Layer):
|
|||
x = self.mid_se(x)
|
||||
x = self.linear_conv(x)
|
||||
if self.if_shortcut:
|
||||
x = paddle.elementwise_add(inputs, x)
|
||||
x = paddle.add(inputs, x)
|
||||
return x
|
||||
|
||||
|
||||
class SEModule(nn.Layer):
|
||||
def __init__(self, in_channels, reduction=4, name=""):
|
||||
super(SEModule, self).__init__()
|
||||
self.avg_pool = nn.Pool2D(
|
||||
pool_type="avg", global_pooling=True, use_cudnn=False)
|
||||
self.conv1 = nn.Conv2d(
|
||||
self.avg_pool = nn.AdaptiveAvgPool2D(1)
|
||||
self.conv1 = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels // reduction,
|
||||
kernel_size=1,
|
||||
|
@ -259,7 +258,7 @@ class SEModule(nn.Layer):
|
|||
padding=0,
|
||||
weight_attr=ParamAttr(name=name + "_1_weights"),
|
||||
bias_attr=ParamAttr(name=name + "_1_offset"))
|
||||
self.conv2 = nn.Conv2d(
|
||||
self.conv2 = nn.Conv2D(
|
||||
in_channels=in_channels // reduction,
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
|
@ -273,5 +272,5 @@ class SEModule(nn.Layer):
|
|||
outputs = self.conv1(outputs)
|
||||
outputs = F.relu(outputs)
|
||||
outputs = self.conv2(outputs)
|
||||
outputs = F.hard_sigmoid(outputs)
|
||||
outputs = F.activation.hard_sigmoid(outputs)
|
||||
return inputs * outputs
|
|
@ -127,7 +127,7 @@ class MobileNetV3(nn.Layer):
|
|||
act='hard_swish',
|
||||
name='conv_last')
|
||||
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
|
||||
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
|
||||
self.out_channels = make_divisible(scale * cls_ch_squeeze)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -33,7 +33,7 @@ def get_bias_attr(k, name):
|
|||
class Head(nn.Layer):
|
||||
def __init__(self, in_channels, name_list):
|
||||
super(Head, self).__init__()
|
||||
self.conv1 = nn.Conv2d(
|
||||
self.conv1 = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels // 4,
|
||||
kernel_size=3,
|
||||
|
@ -51,14 +51,14 @@ class Head(nn.Layer):
|
|||
moving_mean_name=name_list[1] + '.w_1',
|
||||
moving_variance_name=name_list[1] + '.w_2',
|
||||
act='relu')
|
||||
self.conv2 = nn.ConvTranspose2d(
|
||||
self.conv2 = nn.Conv2DTranspose(
|
||||
in_channels=in_channels // 4,
|
||||
out_channels=in_channels // 4,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
weight_attr=ParamAttr(
|
||||
name=name_list[2] + '.w_0',
|
||||
initializer=paddle.nn.initializer.MSRA(uniform=False)),
|
||||
initializer=paddle.nn.initializer.KaimingNormal()),
|
||||
bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv2"))
|
||||
self.conv_bn2 = nn.BatchNorm(
|
||||
num_channels=in_channels // 4,
|
||||
|
@ -71,14 +71,14 @@ class Head(nn.Layer):
|
|||
moving_mean_name=name_list[3] + '.w_1',
|
||||
moving_variance_name=name_list[3] + '.w_2',
|
||||
act="relu")
|
||||
self.conv3 = nn.ConvTranspose2d(
|
||||
self.conv3 = nn.Conv2DTranspose(
|
||||
in_channels=in_channels // 4,
|
||||
out_channels=1,
|
||||
kernel_size=2,
|
||||
stride=2,
|
||||
weight_attr=ParamAttr(
|
||||
name=name_list[4] + '.w_0',
|
||||
initializer=paddle.nn.initializer.MSRA(uniform=False)),
|
||||
initializer=paddle.nn.initializer.KaimingNormal()),
|
||||
bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv3"),
|
||||
)
|
||||
|
||||
|
|
|
@ -26,37 +26,37 @@ class DBFPN(nn.Layer):
|
|||
def __init__(self, in_channels, out_channels, **kwargs):
|
||||
super(DBFPN, self).__init__()
|
||||
self.out_channels = out_channels
|
||||
weight_attr = paddle.nn.initializer.MSRA(uniform=False)
|
||||
weight_attr = paddle.nn.initializer.KaimingNormal()
|
||||
|
||||
self.in2_conv = nn.Conv2d(
|
||||
self.in2_conv = nn.Conv2D(
|
||||
in_channels=in_channels[0],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_51.w_0', initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.in3_conv = nn.Conv2d(
|
||||
self.in3_conv = nn.Conv2D(
|
||||
in_channels=in_channels[1],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_50.w_0', initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.in4_conv = nn.Conv2d(
|
||||
self.in4_conv = nn.Conv2D(
|
||||
in_channels=in_channels[2],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_49.w_0', initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.in5_conv = nn.Conv2d(
|
||||
self.in5_conv = nn.Conv2D(
|
||||
in_channels=in_channels[3],
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
weight_attr=ParamAttr(
|
||||
name='conv2d_48.w_0', initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p5_conv = nn.Conv2d(
|
||||
self.p5_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
|
@ -64,7 +64,7 @@ class DBFPN(nn.Layer):
|
|||
weight_attr=ParamAttr(
|
||||
name='conv2d_52.w_0', initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p4_conv = nn.Conv2d(
|
||||
self.p4_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
|
@ -72,7 +72,7 @@ class DBFPN(nn.Layer):
|
|||
weight_attr=ParamAttr(
|
||||
name='conv2d_53.w_0', initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p3_conv = nn.Conv2d(
|
||||
self.p3_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
|
@ -80,7 +80,7 @@ class DBFPN(nn.Layer):
|
|||
weight_attr=ParamAttr(
|
||||
name='conv2d_54.w_0', initializer=weight_attr),
|
||||
bias_attr=False)
|
||||
self.p2_conv = nn.Conv2d(
|
||||
self.p2_conv = nn.Conv2D(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels // 4,
|
||||
kernel_size=3,
|
||||
|
@ -97,17 +97,17 @@ class DBFPN(nn.Layer):
|
|||
in3 = self.in3_conv(c3)
|
||||
in2 = self.in2_conv(c2)
|
||||
|
||||
out4 = in4 + F.resize_nearest(in5, scale=2) # 1/16
|
||||
out3 = in3 + F.resize_nearest(out4, scale=2) # 1/8
|
||||
out2 = in2 + F.resize_nearest(out3, scale=2) # 1/4
|
||||
out4 = in4 + F.upsample(in5, scale_factor=2, mode="nearest") # 1/16
|
||||
out3 = in3 + F.upsample(out4, scale_factor=2, mode="nearest") # 1/8
|
||||
out2 = in2 + F.upsample(out3, scale_factor=2, mode="nearest") # 1/4
|
||||
|
||||
p5 = self.p5_conv(in5)
|
||||
p4 = self.p4_conv(out4)
|
||||
p3 = self.p3_conv(out3)
|
||||
p2 = self.p2_conv(out2)
|
||||
p5 = F.resize_nearest(p5, scale=8)
|
||||
p4 = F.resize_nearest(p4, scale=4)
|
||||
p3 = F.resize_nearest(p3, scale=2)
|
||||
p5 = F.upsample(p5, scale_factor=8, mode="nearest")
|
||||
p4 = F.upsample(p4, scale_factor=4, mode="nearest")
|
||||
p3 = F.upsample(p3, scale_factor=2, mode="nearest")
|
||||
|
||||
fuse = paddle.concat([p5, p4, p3, p2], axis=1)
|
||||
return fuse
|
||||
|
|
|
@ -50,9 +50,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters):
|
|||
|
||||
# step3 build optimizer
|
||||
optim_name = config.pop('name')
|
||||
# Regularization is invalid. The bug will be fixed in paddle-rc. The param is
|
||||
# weight_decay.
|
||||
optim = getattr(optimizer, optim_name)(learning_rate=lr,
|
||||
regularization=reg,
|
||||
weight_decay=reg,
|
||||
**config)
|
||||
return optim(parameters), lr
|
||||
|
|
|
@ -17,7 +17,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from paddle.optimizer import lr_scheduler
|
||||
from paddle.optimizer import lr as lr_scheduler
|
||||
|
||||
|
||||
class Linear(object):
|
||||
|
|
|
@ -52,7 +52,6 @@ def get_logger(name='ppocr', log_file=None, log_level=logging.INFO):
|
|||
stream_handler = logging.StreamHandler(stream=sys.stdout)
|
||||
stream_handler.setFormatter(formatter)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
if log_file is not None and dist.get_rank() == 0:
|
||||
log_file_folder = os.path.split(log_file)[0]
|
||||
os.makedirs(log_file_folder, exist_ok=True)
|
||||
|
|
|
@ -42,16 +42,12 @@ def _mkdir_if_not_exist(path, logger):
|
|||
raise OSError('Failed to mkdir {}'.format(path))
|
||||
|
||||
|
||||
def load_dygraph_pretrain(
|
||||
model,
|
||||
logger,
|
||||
path=None,
|
||||
load_static_weights=False):
|
||||
def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
|
||||
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
|
||||
raise ValueError("Model pretrain path {} does not "
|
||||
"exists.".format(path))
|
||||
if load_static_weights:
|
||||
pre_state_dict = paddle.io.load_program_state(path)
|
||||
pre_state_dict = paddle.static.load_program_state(path)
|
||||
param_state_dict = {}
|
||||
model_dict = model.state_dict()
|
||||
for key in model_dict.keys():
|
||||
|
@ -113,15 +109,11 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
|
|||
if not isinstance(pretrained_model, list):
|
||||
pretrained_model = [pretrained_model]
|
||||
if not isinstance(load_static_weights, list):
|
||||
load_static_weights = [load_static_weights] * len(
|
||||
pretrained_model)
|
||||
load_static_weights = [load_static_weights] * len(pretrained_model)
|
||||
for idx, pretrained in enumerate(pretrained_model):
|
||||
load_static = load_static_weights[idx]
|
||||
load_dygraph_pretrain(
|
||||
model,
|
||||
logger,
|
||||
path=pretrained,
|
||||
load_static_weights=load_static)
|
||||
model, logger, path=pretrained, load_static_weights=load_static)
|
||||
logger.info("load pretrained model from {}".format(
|
||||
pretrained_model))
|
||||
else:
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
|
||||
import paddle
|
||||
from paddle.jit import to_static
|
||||
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.utils.save_load import init_model
|
||||
from tools.program import load_config
|
||||
from tools.program import merge_config
|
||||
|
||||
|
||||
def parse_args():
|
||||
def str2bool(v):
|
||||
return v.lower() in ("true", "t", "1")
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-c", "--config", help="configuration file to use")
|
||||
parser.add_argument(
|
||||
"-o", "--output_path", type=str, default='./output/infer/')
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
class Model(paddle.nn.Layer):
|
||||
def __init__(self, model):
|
||||
super(Model, self).__init__()
|
||||
self.pre_model = model
|
||||
|
||||
# Please modify the 'shape' according to actual needs
|
||||
@to_static(input_spec=[
|
||||
paddle.static.InputSpec(
|
||||
shape=[None, 3, 32, None], dtype='float32')
|
||||
])
|
||||
def forward(self, inputs):
|
||||
x = self.pre_model(inputs)
|
||||
return x
|
||||
|
||||
|
||||
def main():
|
||||
FLAGS = parse_args()
|
||||
config = load_config(FLAGS.config)
|
||||
merge_config(FLAGS.opt)
|
||||
|
||||
# build post process
|
||||
post_process_class = build_post_process(config['PostProcess'],
|
||||
config['Global'])
|
||||
|
||||
# build model
|
||||
#for rec algorithm
|
||||
if hasattr(post_process_class, 'character'):
|
||||
char_num = len(getattr(post_process_class, 'character'))
|
||||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
model = build_model(config['Architecture'])
|
||||
init_model(config, model, logger)
|
||||
model.eval()
|
||||
|
||||
model = Model(model)
|
||||
paddle.jit.save(model, FLAGS.output_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -33,6 +33,7 @@ from ppocr.utils.logging import get_logger
|
|||
from ppocr.data import build_dataloader
|
||||
import numpy as np
|
||||
|
||||
|
||||
class ArgsParser(ArgumentParser):
|
||||
def __init__(self):
|
||||
super(ArgsParser, self).__init__(
|
||||
|
@ -185,7 +186,7 @@ def train(config,
|
|||
for epoch in range(start_epoch, epoch_num):
|
||||
if epoch > 0:
|
||||
train_loader = build_dataloader(config, 'Train', device)
|
||||
|
||||
|
||||
for idx, batch in enumerate(train_dataloader):
|
||||
if idx >= len(train_dataloader):
|
||||
break
|
||||
|
@ -196,12 +197,7 @@ def train(config,
|
|||
preds = model(images)
|
||||
loss = loss_class(preds, batch)
|
||||
avg_loss = loss['loss']
|
||||
if config['Global']['distributed']:
|
||||
avg_loss = model.scale_loss(avg_loss)
|
||||
avg_loss.backward()
|
||||
model.apply_collective_grads()
|
||||
else:
|
||||
avg_loss.backward()
|
||||
avg_loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.clear_grad()
|
||||
if not isinstance(lr_scheduler, float):
|
||||
|
@ -227,7 +223,8 @@ def train(config,
|
|||
vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step)
|
||||
vdl_writer.add_scalar('TRAIN/lr', lr, global_step)
|
||||
|
||||
if global_step > 0 and global_step % print_batch_step == 0:
|
||||
if dist.get_rank(
|
||||
) == 0 and global_step > 0 and global_step % print_batch_step == 0:
|
||||
logs = train_stats.log()
|
||||
strs = 'epoch: [{}/{}], iter: {}, {}, time: {:.3f}'.format(
|
||||
epoch, epoch_num, global_step, logs, train_batch_elapse)
|
||||
|
@ -235,8 +232,8 @@ def train(config,
|
|||
# eval
|
||||
if global_step > start_eval_step and \
|
||||
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
|
||||
cur_metirc = eval(model, valid_dataloader,
|
||||
post_process_class, eval_class, logger, print_batch_step)
|
||||
cur_metirc = eval(model, valid_dataloader, post_process_class,
|
||||
eval_class, logger, print_batch_step)
|
||||
cur_metirc_str = 'cur metirc, {}'.format(', '.join(
|
||||
['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
|
||||
logger.info(cur_metirc_str)
|
||||
|
@ -298,18 +295,17 @@ def train(config,
|
|||
return
|
||||
|
||||
|
||||
def eval(model, valid_dataloader,
|
||||
post_process_class, eval_class,
|
||||
logger, print_batch_step):
|
||||
def eval(model, valid_dataloader, post_process_class, eval_class, logger,
|
||||
print_batch_step):
|
||||
model.eval()
|
||||
with paddle.no_grad():
|
||||
total_frame = 0.0
|
||||
total_time = 0.0
|
||||
# pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
|
||||
# pbar = tqdm(total=len(valid_dataloader), desc='eval model:')
|
||||
for idx, batch in enumerate(valid_dataloader):
|
||||
if idx >= len(valid_dataloader):
|
||||
break
|
||||
images = paddle.to_variable(batch[0])
|
||||
images = paddle.to_tensor(batch[0])
|
||||
start = time.time()
|
||||
preds = model(images)
|
||||
|
||||
|
@ -319,13 +315,14 @@ def eval(model, valid_dataloader,
|
|||
total_time += time.time() - start
|
||||
# Evaluate the results of the current batch
|
||||
eval_class(post_result, batch)
|
||||
# pbar.update(1)
|
||||
# pbar.update(1)
|
||||
total_frame += len(images)
|
||||
if idx % print_batch_step == 0:
|
||||
if idx % print_batch_step == 0 and dist.get_rank() == 0:
|
||||
logger.info('tackling images for eval: {}/{}'.format(
|
||||
idx, len(valid_dataloader)))
|
||||
# Get final metirc,eg. acc or hmean
|
||||
metirc = eval_class.get_metric()
|
||||
|
||||
# pbar.close()
|
||||
model.train()
|
||||
metirc['fps'] = total_frame / total_time
|
||||
|
@ -348,16 +345,15 @@ def preprocess():
|
|||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
device = paddle.set_device(device)
|
||||
|
||||
|
||||
config['Global']['distributed'] = dist.get_world_size() != 1
|
||||
paddle.disable_static(device)
|
||||
|
||||
# save_config
|
||||
save_model_dir = config['Global']['save_model_dir']
|
||||
os.makedirs(save_model_dir, exist_ok=True)
|
||||
with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
|
||||
yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
|
||||
|
||||
|
||||
logger = get_logger(log_file='{}/train.log'.format(save_model_dir))
|
||||
if config['Global']['use_visualdl']:
|
||||
from visualdl import LogWriter
|
||||
|
|
|
@ -27,9 +27,8 @@ import yaml
|
|||
import paddle
|
||||
import paddle.distributed as dist
|
||||
|
||||
paddle.manual_seed(2)
|
||||
paddle.seed(2)
|
||||
|
||||
from ppocr.utils.logging import get_logger
|
||||
from ppocr.data import build_dataloader
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.losses import build_loss
|
||||
|
@ -49,18 +48,18 @@ def main(config, device, logger, vdl_writer):
|
|||
dist.init_parallel_env()
|
||||
|
||||
global_config = config['Global']
|
||||
|
||||
|
||||
# build dataloader
|
||||
train_dataloader = build_dataloader(config, 'Train', device)
|
||||
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
||||
if config['Eval']:
|
||||
valid_dataloader = build_dataloader(config, 'Eval', device)
|
||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||
else:
|
||||
valid_dataloader = None
|
||||
|
||||
# build post process
|
||||
post_process_class = build_post_process(
|
||||
config['PostProcess'], global_config)
|
||||
|
||||
post_process_class = build_post_process(config['PostProcess'],
|
||||
global_config)
|
||||
|
||||
# build model
|
||||
#for rec algorithm
|
||||
if hasattr(post_process_class, 'character'):
|
||||
|
@ -72,38 +71,29 @@ def main(config, device, logger, vdl_writer):
|
|||
|
||||
# build loss
|
||||
loss_class = build_loss(config['Loss'])
|
||||
|
||||
|
||||
# build optim
|
||||
optimizer, lr_scheduler = build_optimizer(config['Optimizer'],
|
||||
optimizer, lr_scheduler = build_optimizer(
|
||||
config['Optimizer'],
|
||||
epochs=config['Global']['epoch_num'],
|
||||
step_each_epoch=len(train_dataloader),
|
||||
parameters=model.parameters())
|
||||
|
||||
# build metric
|
||||
eval_class = build_metric(config['Metric'])
|
||||
|
||||
|
||||
# load pretrain model
|
||||
pre_best_model_dict = init_model(config, model, logger, optimizer)
|
||||
|
||||
# start train
|
||||
program.train(config,
|
||||
train_dataloader,
|
||||
valid_dataloader,
|
||||
device,
|
||||
model,
|
||||
loss_class,
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
post_process_class,
|
||||
eval_class,
|
||||
pre_best_model_dict,
|
||||
logger,
|
||||
vdl_writer)
|
||||
program.train(config, train_dataloader, valid_dataloader, device, model,
|
||||
loss_class, optimizer, lr_scheduler, post_process_class,
|
||||
eval_class, pre_best_model_dict, logger, vdl_writer)
|
||||
|
||||
|
||||
def test_reader(config, device, logger):
|
||||
loader = build_dataloader(config, 'Train', device)
|
||||
# loader = build_dataloader(config, 'Eval', device)
|
||||
# loader = build_dataloader(config, 'Eval', device)
|
||||
import time
|
||||
starttime = time.time()
|
||||
count = 0
|
||||
|
@ -113,11 +103,13 @@ def test_reader(config, device, logger):
|
|||
if count % 1 == 0:
|
||||
batch_time = time.time() - starttime
|
||||
starttime = time.time()
|
||||
logger.info("reader: {}, {}, {}".format(count, len(data), batch_time))
|
||||
logger.info("reader: {}, {}, {}".format(count,
|
||||
len(data), batch_time))
|
||||
except Exception as e:
|
||||
logger.info(e)
|
||||
logger.info("finish reader: {}, Success!".format(count))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config, device, logger, vdl_writer = program.preprocess()
|
||||
main(config, device, logger, vdl_writer)
|
||||
|
|
Loading…
Reference in New Issue