diff --git a/configs/det/det_r50_vd_db.yml b/configs/det/bak/det_r50_vd_db.yml similarity index 100% rename from configs/det/det_r50_vd_db.yml rename to configs/det/bak/det_r50_vd_db.yml diff --git a/configs/rec/rec_mv3_none_bilstm_ctc_simple.yml b/configs/rec/bak/rec_mv3_none_bilstm_ctc_simple.yml similarity index 100% rename from configs/rec/rec_mv3_none_bilstm_ctc_simple.yml rename to configs/rec/bak/rec_mv3_none_bilstm_ctc_simple.yml diff --git a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml b/configs/rec/bak/rec_r34_vd_none_bilstm_ctc.yml similarity index 100% rename from configs/rec/rec_r34_vd_none_bilstm_ctc.yml rename to configs/rec/bak/rec_r34_vd_none_bilstm_ctc.yml diff --git a/configs/rec/rec_r34_vd_none_none_ctc.yml b/configs/rec/bak/rec_r34_vd_none_none_ctc.yml similarity index 100% rename from configs/rec/rec_r34_vd_none_none_ctc.yml rename to configs/rec/bak/rec_r34_vd_none_none_ctc.yml diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py index f1b73c07..2f95b377 100644 --- a/ppocr/data/__init__.py +++ b/ppocr/data/__init__.py @@ -37,6 +37,7 @@ from ppocr.data.lmdb_dataset import LMDBDateSet __all__ = ['build_dataloader', 'transform', 'create_operators'] + def term_mp(sig_num, frame): """ kill all child processes """ @@ -45,24 +46,27 @@ def term_mp(sig_num, frame): print("main proc {} exit, kill process group " "{}".format(pid, pgid)) os.killpg(pgid, signal.SIGKILL) + signal.signal(signal.SIGINT, term_mp) signal.signal(signal.SIGTERM, term_mp) -def build_dataloader(config, mode, device): + +def build_dataloader(config, mode, device, logger): config = copy.deepcopy(config) - + support_dict = ['SimpleDataSet', 'LMDBDateSet'] module_name = config[mode]['dataset']['name'] assert module_name in support_dict, Exception( 'DataSet only support {}'.format(support_dict)) - assert mode in ['Train', 'Eval', 'Test'], "Mode should be Train, Eval or Test." - - dataset = eval(module_name)(config, mode) + assert mode in ['Train', 'Eval', 'Test' + ], "Mode should be Train, Eval or Test." + + dataset = eval(module_name)(config, mode, logger) loader_config = config[mode]['loader'] batch_size = loader_config['batch_size_per_card'] drop_last = loader_config['drop_last'] num_workers = loader_config['num_workers'] - + if mode == "Train": #Distribute data to multiple cards batch_sampler = DistributedBatchSampler( @@ -76,14 +80,13 @@ def build_dataloader(config, mode, device): dataset=dataset, batch_size=batch_size, shuffle=False, - drop_last=drop_last) - + drop_last=drop_last) + data_loader = DataLoader( dataset=dataset, batch_sampler=batch_sampler, places=device, num_workers=num_workers, return_list=True) - + return data_loader - #return data_loader, _dataset.info_dict \ No newline at end of file diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py index 4cd48674..ffa05228 100644 --- a/ppocr/data/lmdb_dataset.py +++ b/ppocr/data/lmdb_dataset.py @@ -22,37 +22,26 @@ import lmdb import cv2 from .imaug import transform, create_operators -from ppocr.utils.logging import get_logger -logger = get_logger() + class LMDBDateSet(Dataset): - def __init__(self, config, mode): + def __init__(self, config, mode, logger): super(LMDBDateSet, self).__init__() - + global_config = config['Global'] dataset_config = config[mode]['dataset'] loader_config = config[mode]['loader'] batch_size = loader_config['batch_size_per_card'] data_dir = dataset_config['data_dir'] self.do_shuffle = loader_config['shuffle'] - + self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir) - logger.info("Initialize indexs of datasets:%s" % data_dir) self.data_idx_order_list = self.dataset_traversal() if self.do_shuffle: np.random.shuffle(self.data_idx_order_list) self.ops = create_operators(dataset_config['transforms'], global_config) - -# # for rec -# character = '' -# for op in self.ops: -# if hasattr(op, 'character'): -# character = getattr(op, 'character') - -# self.info_dict = {'character': character} - def load_hierarchical_lmdb_dataset(self, data_dir): lmdb_sets = {} dataset_idx = 0 @@ -71,7 +60,7 @@ class LMDBDateSet(Dataset): "txn":txn, "num_samples":num_samples} dataset_idx += 1 return lmdb_sets - + def dataset_traversal(self): lmdb_num = len(self.lmdb_sets) total_sample_num = 0 @@ -88,7 +77,7 @@ class LMDBDateSet(Dataset): data_idx_order_list[beg_idx:end_idx, 1] += 1 beg_idx = beg_idx + tmp_sample_num return data_idx_order_list - + def get_img_data(self, value): """get_img_data""" if not value: @@ -110,15 +99,15 @@ class LMDBDateSet(Dataset): img_key = 'image-%09d'.encode() % index imgbuf = txn.get(img_key) return imgbuf, label - + def __getitem__(self, idx): lmdb_idx, file_idx = self.data_idx_order_list[idx] lmdb_idx = int(lmdb_idx) file_idx = int(file_idx) - sample_info = self.get_lmdb_sample_info( - self.lmdb_sets[lmdb_idx]['txn'], file_idx) + sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'], + file_idx) if sample_info is None: - return self.__getitem__(np.random.randint(self.__len__())) + return self.__getitem__(np.random.randint(self.__len__())) img, label = sample_info data = {'image': img, 'label': label} outs = transform(data, self.ops) @@ -128,4 +117,3 @@ class LMDBDateSet(Dataset): def __len__(self): return self.data_idx_order_list.shape[0] - diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index fbc03c51..e9813cdc 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -20,18 +20,17 @@ from paddle.io import Dataset import time from .imaug import transform, create_operators -from ppocr.utils.logging import get_logger -logger = get_logger() + class SimpleDataSet(Dataset): - def __init__(self, config, mode): + def __init__(self, config, mode, logger): super(SimpleDataSet, self).__init__() - + global_config = config['Global'] dataset_config = config[mode]['dataset'] loader_config = config[mode]['loader'] batch_size = loader_config['batch_size_per_card'] - + self.delimiter = dataset_config.get('delimiter', '\t') label_file_list = dataset_config.pop('label_file_list') data_source_num = len(label_file_list) @@ -39,19 +38,21 @@ class SimpleDataSet(Dataset): ratio_list = [1.0] else: ratio_list = dataset_config.pop('ratio_list') - + assert sum(ratio_list) == 1, "The sum of the ratio_list should be 1." - assert len(ratio_list) == data_source_num, "The length of ratio_list should be the same as the file_list." + assert len( + ratio_list + ) == data_source_num, "The length of ratio_list should be the same as the file_list." self.data_dir = dataset_config['data_dir'] self.do_shuffle = loader_config['shuffle'] - + logger.info("Initialize indexs of datasets:%s" % label_file_list) self.data_lines_list, data_num_list = self.get_image_info_list( label_file_list) self.data_idx_order_list = self.dataset_traversal( data_num_list, ratio_list, batch_size) self.shuffle_data_random() - + self.ops = create_operators(dataset_config['transforms'], global_config) def get_image_info_list(self, file_list): @@ -65,7 +66,7 @@ class SimpleDataSet(Dataset): data_lines_list.append(lines) data_num_list.append(len(lines)) return data_lines_list, data_num_list - + def dataset_traversal(self, data_num_list, ratio_list, batch_size): select_num_list = [] dataset_num = len(data_num_list) @@ -87,8 +88,7 @@ class SimpleDataSet(Dataset): cur_index = cur_index_sets[dataset_idx] if cur_index >= data_num_list[dataset_idx]: break - data_idx_order_list.append(( - dataset_idx, cur_index)) + data_idx_order_list.append((dataset_idx, cur_index)) cur_index_sets[dataset_idx] += 1 if finish_read_num == dataset_num: break @@ -99,7 +99,7 @@ class SimpleDataSet(Dataset): for dno in range(len(self.data_lines_list)): random.shuffle(self.data_lines_list[dno]) return - + def __getitem__(self, idx): dataset_idx, file_idx = self.data_idx_order_list[idx] data_line = self.data_lines_list[dataset_idx][file_idx] @@ -119,4 +119,3 @@ class SimpleDataSet(Dataset): def __len__(self): return len(self.data_idx_order_list) - diff --git a/ppocr/modeling/backbones/det_mobilenet_v3.py b/ppocr/modeling/backbones/det_mobilenet_v3.py index ad4065a7..017dce2f 100755 --- a/ppocr/modeling/backbones/det_mobilenet_v3.py +++ b/ppocr/modeling/backbones/det_mobilenet_v3.py @@ -158,7 +158,7 @@ class ConvBNLayer(nn.Layer): super(ConvBNLayer, self).__init__() self.if_act = if_act self.act = act - self.conv = nn.Conv2d( + self.conv = nn.Conv2D( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, @@ -183,7 +183,7 @@ class ConvBNLayer(nn.Layer): if self.act == "relu": x = F.relu(x) elif self.act == "hard_swish": - x = F.hard_swish(x) + x = F.activation.hard_swish(x) else: print("The activation function is selected incorrectly.") exit() @@ -242,16 +242,15 @@ class ResidualUnit(nn.Layer): x = self.mid_se(x) x = self.linear_conv(x) if self.if_shortcut: - x = paddle.elementwise_add(inputs, x) + x = paddle.add(inputs, x) return x class SEModule(nn.Layer): def __init__(self, in_channels, reduction=4, name=""): super(SEModule, self).__init__() - self.avg_pool = nn.Pool2D( - pool_type="avg", global_pooling=True, use_cudnn=False) - self.conv1 = nn.Conv2d( + self.avg_pool = nn.AdaptiveAvgPool2D(1) + self.conv1 = nn.Conv2D( in_channels=in_channels, out_channels=in_channels // reduction, kernel_size=1, @@ -259,7 +258,7 @@ class SEModule(nn.Layer): padding=0, weight_attr=ParamAttr(name=name + "_1_weights"), bias_attr=ParamAttr(name=name + "_1_offset")) - self.conv2 = nn.Conv2d( + self.conv2 = nn.Conv2D( in_channels=in_channels // reduction, out_channels=in_channels, kernel_size=1, @@ -273,5 +272,5 @@ class SEModule(nn.Layer): outputs = self.conv1(outputs) outputs = F.relu(outputs) outputs = self.conv2(outputs) - outputs = F.hard_sigmoid(outputs) + outputs = F.activation.hard_sigmoid(outputs) return inputs * outputs \ No newline at end of file diff --git a/ppocr/modeling/backbones/rec_mobilenet_v3.py b/ppocr/modeling/backbones/rec_mobilenet_v3.py index bcba8600..91e57ffa 100644 --- a/ppocr/modeling/backbones/rec_mobilenet_v3.py +++ b/ppocr/modeling/backbones/rec_mobilenet_v3.py @@ -127,7 +127,7 @@ class MobileNetV3(nn.Layer): act='hard_swish', name='conv_last') - self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) + self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0) self.out_channels = make_divisible(scale * cls_ch_squeeze) def forward(self, x): diff --git a/ppocr/modeling/heads/det_db_head.py b/ppocr/modeling/heads/det_db_head.py index 85149abd..23789910 100644 --- a/ppocr/modeling/heads/det_db_head.py +++ b/ppocr/modeling/heads/det_db_head.py @@ -33,7 +33,7 @@ def get_bias_attr(k, name): class Head(nn.Layer): def __init__(self, in_channels, name_list): super(Head, self).__init__() - self.conv1 = nn.Conv2d( + self.conv1 = nn.Conv2D( in_channels=in_channels, out_channels=in_channels // 4, kernel_size=3, @@ -51,14 +51,14 @@ class Head(nn.Layer): moving_mean_name=name_list[1] + '.w_1', moving_variance_name=name_list[1] + '.w_2', act='relu') - self.conv2 = nn.ConvTranspose2d( + self.conv2 = nn.Conv2DTranspose( in_channels=in_channels // 4, out_channels=in_channels // 4, kernel_size=2, stride=2, weight_attr=ParamAttr( name=name_list[2] + '.w_0', - initializer=paddle.nn.initializer.MSRA(uniform=False)), + initializer=paddle.nn.initializer.KaimingNormal()), bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv2")) self.conv_bn2 = nn.BatchNorm( num_channels=in_channels // 4, @@ -71,14 +71,14 @@ class Head(nn.Layer): moving_mean_name=name_list[3] + '.w_1', moving_variance_name=name_list[3] + '.w_2', act="relu") - self.conv3 = nn.ConvTranspose2d( + self.conv3 = nn.Conv2DTranspose( in_channels=in_channels // 4, out_channels=1, kernel_size=2, stride=2, weight_attr=ParamAttr( name=name_list[4] + '.w_0', - initializer=paddle.nn.initializer.MSRA(uniform=False)), + initializer=paddle.nn.initializer.KaimingNormal()), bias_attr=get_bias_attr(in_channels // 4, name_list[-1] + "conv3"), ) diff --git a/ppocr/modeling/necks/db_fpn.py b/ppocr/modeling/necks/db_fpn.py index 8adabbd7..dbe482b4 100644 --- a/ppocr/modeling/necks/db_fpn.py +++ b/ppocr/modeling/necks/db_fpn.py @@ -26,37 +26,37 @@ class DBFPN(nn.Layer): def __init__(self, in_channels, out_channels, **kwargs): super(DBFPN, self).__init__() self.out_channels = out_channels - weight_attr = paddle.nn.initializer.MSRA(uniform=False) + weight_attr = paddle.nn.initializer.KaimingNormal() - self.in2_conv = nn.Conv2d( + self.in2_conv = nn.Conv2D( in_channels=in_channels[0], out_channels=self.out_channels, kernel_size=1, weight_attr=ParamAttr( name='conv2d_51.w_0', initializer=weight_attr), bias_attr=False) - self.in3_conv = nn.Conv2d( + self.in3_conv = nn.Conv2D( in_channels=in_channels[1], out_channels=self.out_channels, kernel_size=1, weight_attr=ParamAttr( name='conv2d_50.w_0', initializer=weight_attr), bias_attr=False) - self.in4_conv = nn.Conv2d( + self.in4_conv = nn.Conv2D( in_channels=in_channels[2], out_channels=self.out_channels, kernel_size=1, weight_attr=ParamAttr( name='conv2d_49.w_0', initializer=weight_attr), bias_attr=False) - self.in5_conv = nn.Conv2d( + self.in5_conv = nn.Conv2D( in_channels=in_channels[3], out_channels=self.out_channels, kernel_size=1, weight_attr=ParamAttr( name='conv2d_48.w_0', initializer=weight_attr), bias_attr=False) - self.p5_conv = nn.Conv2d( + self.p5_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, @@ -64,7 +64,7 @@ class DBFPN(nn.Layer): weight_attr=ParamAttr( name='conv2d_52.w_0', initializer=weight_attr), bias_attr=False) - self.p4_conv = nn.Conv2d( + self.p4_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, @@ -72,7 +72,7 @@ class DBFPN(nn.Layer): weight_attr=ParamAttr( name='conv2d_53.w_0', initializer=weight_attr), bias_attr=False) - self.p3_conv = nn.Conv2d( + self.p3_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, @@ -80,7 +80,7 @@ class DBFPN(nn.Layer): weight_attr=ParamAttr( name='conv2d_54.w_0', initializer=weight_attr), bias_attr=False) - self.p2_conv = nn.Conv2d( + self.p2_conv = nn.Conv2D( in_channels=self.out_channels, out_channels=self.out_channels // 4, kernel_size=3, @@ -97,17 +97,17 @@ class DBFPN(nn.Layer): in3 = self.in3_conv(c3) in2 = self.in2_conv(c2) - out4 = in4 + F.resize_nearest(in5, scale=2) # 1/16 - out3 = in3 + F.resize_nearest(out4, scale=2) # 1/8 - out2 = in2 + F.resize_nearest(out3, scale=2) # 1/4 + out4 = in4 + F.upsample(in5, scale_factor=2, mode="nearest") # 1/16 + out3 = in3 + F.upsample(out4, scale_factor=2, mode="nearest") # 1/8 + out2 = in2 + F.upsample(out3, scale_factor=2, mode="nearest") # 1/4 p5 = self.p5_conv(in5) p4 = self.p4_conv(out4) p3 = self.p3_conv(out3) p2 = self.p2_conv(out2) - p5 = F.resize_nearest(p5, scale=8) - p4 = F.resize_nearest(p4, scale=4) - p3 = F.resize_nearest(p3, scale=2) + p5 = F.upsample(p5, scale_factor=8, mode="nearest") + p4 = F.upsample(p4, scale_factor=4, mode="nearest") + p3 = F.upsample(p3, scale_factor=2, mode="nearest") fuse = paddle.concat([p5, p4, p3, p2], axis=1) return fuse diff --git a/ppocr/optimizer/__init__.py b/ppocr/optimizer/__init__.py index 740fc21e..72366e50 100644 --- a/ppocr/optimizer/__init__.py +++ b/ppocr/optimizer/__init__.py @@ -50,9 +50,7 @@ def build_optimizer(config, epochs, step_each_epoch, parameters): # step3 build optimizer optim_name = config.pop('name') - # Regularization is invalid. The bug will be fixed in paddle-rc. The param is - # weight_decay. optim = getattr(optimizer, optim_name)(learning_rate=lr, - regularization=reg, + weight_decay=reg, **config) return optim(parameters), lr diff --git a/ppocr/optimizer/learning_rate.py b/ppocr/optimizer/learning_rate.py index 5b86e846..518e0eef 100644 --- a/ppocr/optimizer/learning_rate.py +++ b/ppocr/optimizer/learning_rate.py @@ -17,7 +17,7 @@ from __future__ import division from __future__ import print_function from __future__ import unicode_literals -from paddle.optimizer import lr_scheduler +from paddle.optimizer import lr as lr_scheduler class Linear(object): diff --git a/ppocr/utils/logging.py b/ppocr/utils/logging.py index 150538a7..e3fa6b23 100644 --- a/ppocr/utils/logging.py +++ b/ppocr/utils/logging.py @@ -52,7 +52,6 @@ def get_logger(name='ppocr', log_file=None, log_level=logging.INFO): stream_handler = logging.StreamHandler(stream=sys.stdout) stream_handler.setFormatter(formatter) logger.addHandler(stream_handler) - if log_file is not None and dist.get_rank() == 0: log_file_folder = os.path.split(log_file)[0] os.makedirs(log_file_folder, exist_ok=True) diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 1ef20331..c6d20651 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -42,16 +42,12 @@ def _mkdir_if_not_exist(path, logger): raise OSError('Failed to mkdir {}'.format(path)) -def load_dygraph_pretrain( - model, - logger, - path=None, - load_static_weights=False): +def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False): if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): raise ValueError("Model pretrain path {} does not " "exists.".format(path)) if load_static_weights: - pre_state_dict = paddle.io.load_program_state(path) + pre_state_dict = paddle.static.load_program_state(path) param_state_dict = {} model_dict = model.state_dict() for key in model_dict.keys(): @@ -113,15 +109,11 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None): if not isinstance(pretrained_model, list): pretrained_model = [pretrained_model] if not isinstance(load_static_weights, list): - load_static_weights = [load_static_weights] * len( - pretrained_model) + load_static_weights = [load_static_weights] * len(pretrained_model) for idx, pretrained in enumerate(pretrained_model): load_static = load_static_weights[idx] load_dygraph_pretrain( - model, - logger, - path=pretrained, - load_static_weights=load_static) + model, logger, path=pretrained, load_static_weights=load_static) logger.info("load pretrained model from {}".format( pretrained_model)) else: diff --git a/tools/export_model.py b/tools/export_model.py new file mode 100755 index 00000000..60c05725 --- /dev/null +++ b/tools/export_model.py @@ -0,0 +1,76 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import paddle +from paddle.jit import to_static + +from ppocr.modeling.architectures import build_model +from ppocr.postprocess import build_post_process +from ppocr.utils.save_load import init_model +from tools.program import load_config +from tools.program import merge_config + + +def parse_args(): + def str2bool(v): + return v.lower() in ("true", "t", "1") + + parser = argparse.ArgumentParser() + parser.add_argument("-c", "--config", help="configuration file to use") + parser.add_argument( + "-o", "--output_path", type=str, default='./output/infer/') + return parser.parse_args() + + +class Model(paddle.nn.Layer): + def __init__(self, model): + super(Model, self).__init__() + self.pre_model = model + + # Please modify the 'shape' according to actual needs + @to_static(input_spec=[ + paddle.static.InputSpec( + shape=[None, 3, 32, None], dtype='float32') + ]) + def forward(self, inputs): + x = self.pre_model(inputs) + return x + + +def main(): + FLAGS = parse_args() + config = load_config(FLAGS.config) + merge_config(FLAGS.opt) + + # build post process + post_process_class = build_post_process(config['PostProcess'], + config['Global']) + + # build model + #for rec algorithm + if hasattr(post_process_class, 'character'): + char_num = len(getattr(post_process_class, 'character')) + config['Architecture']["Head"]['out_channels'] = char_num + model = build_model(config['Architecture']) + init_model(config, model, logger) + model.eval() + + model = Model(model) + paddle.jit.save(model, FLAGS.output_path) + + +if __name__ == "__main__": + main() diff --git a/tools/program.py b/tools/program.py index da28005a..696700c0 100755 --- a/tools/program.py +++ b/tools/program.py @@ -33,6 +33,7 @@ from ppocr.utils.logging import get_logger from ppocr.data import build_dataloader import numpy as np + class ArgsParser(ArgumentParser): def __init__(self): super(ArgsParser, self).__init__( @@ -185,7 +186,7 @@ def train(config, for epoch in range(start_epoch, epoch_num): if epoch > 0: train_loader = build_dataloader(config, 'Train', device) - + for idx, batch in enumerate(train_dataloader): if idx >= len(train_dataloader): break @@ -196,12 +197,7 @@ def train(config, preds = model(images) loss = loss_class(preds, batch) avg_loss = loss['loss'] - if config['Global']['distributed']: - avg_loss = model.scale_loss(avg_loss) - avg_loss.backward() - model.apply_collective_grads() - else: - avg_loss.backward() + avg_loss.backward() optimizer.step() optimizer.clear_grad() if not isinstance(lr_scheduler, float): @@ -227,7 +223,8 @@ def train(config, vdl_writer.add_scalar('TRAIN/{}'.format(k), v, global_step) vdl_writer.add_scalar('TRAIN/lr', lr, global_step) - if global_step > 0 and global_step % print_batch_step == 0: + if dist.get_rank( + ) == 0 and global_step > 0 and global_step % print_batch_step == 0: logs = train_stats.log() strs = 'epoch: [{}/{}], iter: {}, {}, time: {:.3f}'.format( epoch, epoch_num, global_step, logs, train_batch_elapse) @@ -235,8 +232,8 @@ def train(config, # eval if global_step > start_eval_step and \ (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: - cur_metirc = eval(model, valid_dataloader, - post_process_class, eval_class, logger, print_batch_step) + cur_metirc = eval(model, valid_dataloader, post_process_class, + eval_class, logger, print_batch_step) cur_metirc_str = 'cur metirc, {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in cur_metirc.items()])) logger.info(cur_metirc_str) @@ -298,18 +295,17 @@ def train(config, return -def eval(model, valid_dataloader, - post_process_class, eval_class, - logger, print_batch_step): +def eval(model, valid_dataloader, post_process_class, eval_class, logger, + print_batch_step): model.eval() with paddle.no_grad(): total_frame = 0.0 total_time = 0.0 -# pbar = tqdm(total=len(valid_dataloader), desc='eval model:') + # pbar = tqdm(total=len(valid_dataloader), desc='eval model:') for idx, batch in enumerate(valid_dataloader): if idx >= len(valid_dataloader): break - images = paddle.to_variable(batch[0]) + images = paddle.to_tensor(batch[0]) start = time.time() preds = model(images) @@ -319,13 +315,14 @@ def eval(model, valid_dataloader, total_time += time.time() - start # Evaluate the results of the current batch eval_class(post_result, batch) -# pbar.update(1) + # pbar.update(1) total_frame += len(images) - if idx % print_batch_step == 0: + if idx % print_batch_step == 0 and dist.get_rank() == 0: logger.info('tackling images for eval: {}/{}'.format( idx, len(valid_dataloader))) # Get final metirc,eg. acc or hmean metirc = eval_class.get_metric() + # pbar.close() model.train() metirc['fps'] = total_frame / total_time @@ -348,16 +345,15 @@ def preprocess(): device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' device = paddle.set_device(device) - + config['Global']['distributed'] = dist.get_world_size() != 1 - paddle.disable_static(device) # save_config save_model_dir = config['Global']['save_model_dir'] os.makedirs(save_model_dir, exist_ok=True) with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f: yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False) - + logger = get_logger(log_file='{}/train.log'.format(save_model_dir)) if config['Global']['use_visualdl']: from visualdl import LogWriter diff --git a/tools/train.py b/tools/train.py index 54b9e25b..bdba7dba 100755 --- a/tools/train.py +++ b/tools/train.py @@ -27,9 +27,8 @@ import yaml import paddle import paddle.distributed as dist -paddle.manual_seed(2) +paddle.seed(2) -from ppocr.utils.logging import get_logger from ppocr.data import build_dataloader from ppocr.modeling.architectures import build_model from ppocr.losses import build_loss @@ -49,18 +48,18 @@ def main(config, device, logger, vdl_writer): dist.init_parallel_env() global_config = config['Global'] - + # build dataloader - train_dataloader = build_dataloader(config, 'Train', device) + train_dataloader = build_dataloader(config, 'Train', device, logger) if config['Eval']: - valid_dataloader = build_dataloader(config, 'Eval', device) + valid_dataloader = build_dataloader(config, 'Eval', device, logger) else: valid_dataloader = None # build post process - post_process_class = build_post_process( - config['PostProcess'], global_config) - + post_process_class = build_post_process(config['PostProcess'], + global_config) + # build model #for rec algorithm if hasattr(post_process_class, 'character'): @@ -72,38 +71,29 @@ def main(config, device, logger, vdl_writer): # build loss loss_class = build_loss(config['Loss']) - + # build optim - optimizer, lr_scheduler = build_optimizer(config['Optimizer'], + optimizer, lr_scheduler = build_optimizer( + config['Optimizer'], epochs=config['Global']['epoch_num'], step_each_epoch=len(train_dataloader), parameters=model.parameters()) # build metric eval_class = build_metric(config['Metric']) - + # load pretrain model pre_best_model_dict = init_model(config, model, logger, optimizer) # start train - program.train(config, - train_dataloader, - valid_dataloader, - device, - model, - loss_class, - optimizer, - lr_scheduler, - post_process_class, - eval_class, - pre_best_model_dict, - logger, - vdl_writer) + program.train(config, train_dataloader, valid_dataloader, device, model, + loss_class, optimizer, lr_scheduler, post_process_class, + eval_class, pre_best_model_dict, logger, vdl_writer) def test_reader(config, device, logger): loader = build_dataloader(config, 'Train', device) -# loader = build_dataloader(config, 'Eval', device) + # loader = build_dataloader(config, 'Eval', device) import time starttime = time.time() count = 0 @@ -113,11 +103,13 @@ def test_reader(config, device, logger): if count % 1 == 0: batch_time = time.time() - starttime starttime = time.time() - logger.info("reader: {}, {}, {}".format(count, len(data), batch_time)) + logger.info("reader: {}, {}, {}".format(count, + len(data), batch_time)) except Exception as e: logger.info(e) logger.info("finish reader: {}, Success!".format(count)) + if __name__ == '__main__': config, device, logger, vdl_writer = program.preprocess() main(config, device, logger, vdl_writer)