From 5f2f08a09c04aec628d18d613eb806a5aa70f0d6 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Wed, 9 Dec 2020 14:59:04 +0800 Subject: [PATCH 1/7] add ppocr_v2 ch_db --- configs/det/ch_det_mv3_db.yml | 134 +++++++++++++++++++ configs/det/ch_det_res18_db.yml | 133 ++++++++++++++++++ ppocr/data/imaug/operators.py | 2 + ppocr/data/simple_dataset.py | 28 +++- ppocr/modeling/backbones/det_mobilenet_v3.py | 16 ++- ppocr/postprocess/db_postprocess.py | 10 +- ppocr/utils/save_load.py | 4 +- 7 files changed, 313 insertions(+), 14 deletions(-) create mode 100644 configs/det/ch_det_mv3_db.yml create mode 100644 configs/det/ch_det_res18_db.yml diff --git a/configs/det/ch_det_mv3_db.yml b/configs/det/ch_det_mv3_db.yml new file mode 100644 index 00000000..275c71b9 --- /dev/null +++ b/configs/det/ch_det_mv3_db.yml @@ -0,0 +1,134 @@ +Global: + use_gpu: true + epoch_num: 1200 + log_smooth_window: 20 + print_batch_step: 2 + save_model_dir: ./output/ch_db_mv3/ + save_epoch_step: 1200 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [3000, 2000] + # if pretrained_model is saved in static mode, load_static_weights must set to True + load_static_weights: True + cal_metric_during_train: False + pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained + checkpoints: #./output/det_db_0.001_DiceLoss_256_pp_config_2.0b_4gpu/best_accuracy + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_en/img_10.jpg + save_res_path: ./output/det_db/predicts_db.txt + +Architecture: + model_type: det + algorithm: DB + Transform: + Backbone: + name: MobileNetV3 + scale: 0.5 + model_name: large + disable_se: True + Neck: + name: DBFPN + out_channels: 96 + Head: + name: DBHead + k: 50 + +Loss: + name: DBLoss + balance_loss: true + main_loss_type: DiceLoss + alpha: 5 + beta: 10 + ohem_ratio: 3 + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Cosine + learning_rate: 0.001 + warmup_epoch: 2 + regularizer: + name: 'L2' + factor: 0 + +PostProcess: + name: DBPostProcess + thresh: 0.3 + box_thresh: 0.6 + max_candidates: 1000 + unclip_ratio: 1.5 + +Metric: + name: DetMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt + ratio_list: [1.0] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - IaaAugment: + augmenter_args: + - { 'type': Fliplr, 'args': { 'p': 0.5 } } + - { 'type': Affine, 'args': { 'rotate': [-10, 10] } } + - { 'type': Resize, 'args': { 'size': [0.5, 3] } } + - EastRandomCropData: + size: [960, 960] + max_tries: 50 + keep_ratio: true + - MakeBorderMap: + shrink_ratio: 0.4 + thresh_min: 0.3 + thresh_max: 0.7 + - MakeShrinkMap: + shrink_ratio: 0.4 + min_text_size: 8 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list + loader: + shuffle: True + drop_last: False + batch_size_per_card: 8 + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: +# image_shape: [736, 1280] + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 2 diff --git a/configs/det/ch_det_res18_db.yml b/configs/det/ch_det_res18_db.yml new file mode 100644 index 00000000..9c903fa4 --- /dev/null +++ b/configs/det/ch_det_res18_db.yml @@ -0,0 +1,133 @@ +Global: + use_gpu: true + epoch_num: 1200 + log_smooth_window: 20 + print_batch_step: 2 + save_model_dir: ./output/ch_db_res18/ + save_epoch_step: 1200 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [3000, 2000] + # if pretrained_model is saved in static mode, load_static_weights must set to True + load_static_weights: True + cal_metric_during_train: False + pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained + checkpoints: #./output/det_db_0.001_DiceLoss_256_pp_config_2.0b_4gpu/best_accuracy + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_en/img_10.jpg + save_res_path: ./output/det_db/predicts_db.txt + +Architecture: + model_type: det + algorithm: DB + Transform: + Backbone: + name: ResNet + layers: 18 + disable_se: True + Neck: + name: DBFPN + out_channels: 256 + Head: + name: DBHead + k: 50 + +Loss: + name: DBLoss + balance_loss: true + main_loss_type: DiceLoss + alpha: 5 + beta: 10 + ohem_ratio: 3 + +Optimizer: + name: Adam + beta1: 0.9 + beta2: 0.999 + lr: + name: Cosine + learning_rate: 0.001 + warmup_epoch: 2 + regularizer: + name: 'L2' + factor: 0 + +PostProcess: + name: DBPostProcess + thresh: 0.3 + box_thresh: 0.6 + max_candidates: 1000 + unclip_ratio: 1.5 + +Metric: + name: DetMetric + main_indicator: hmean + +Train: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt + ratio_list: [1.0] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - IaaAugment: + augmenter_args: + - { 'type': Fliplr, 'args': { 'p': 0.5 } } + - { 'type': Affine, 'args': { 'rotate': [-10, 10] } } + - { 'type': Resize, 'args': { 'size': [0.5, 3] } } + - EastRandomCropData: + size: [960, 960] + max_tries: 50 + keep_ratio: true + - MakeBorderMap: + shrink_ratio: 0.4 + thresh_min: 0.3 + thresh_max: 0.7 + - MakeShrinkMap: + shrink_ratio: 0.4 + min_text_size: 8 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list + loader: + shuffle: True + drop_last: False + batch_size_per_card: 8 + num_workers: 4 + +Eval: + dataset: + name: SimpleDataSet + data_dir: ./train_data/icdar2015/text_localization/ + label_file_list: + - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - DetLabelEncode: # Class handling label + - DetResizeForTest: +# image_shape: [736, 1280] + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: 'hwc' + - ToCHWImage: + - KeepKeys: + keep_keys: ['image', 'shape', 'polys', 'ignore_tags'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 1 # must be 1 + num_workers: 2 diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py index 74b60de4..927aa640 100644 --- a/ppocr/data/imaug/operators.py +++ b/ppocr/data/imaug/operators.py @@ -42,6 +42,8 @@ class DecodeImage(object): img) > 0, "invalid input 'img' in DecodeImage" img = np.frombuffer(img, dtype='uint8') img = cv2.imdecode(img, 1) + if img is None: + return None if self.img_mode == 'GRAY': img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) elif self.img_mode == 'RGB': diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index 097da768..16891326 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -27,7 +27,10 @@ class SimpleDataSet(Dataset): global_config = config['Global'] dataset_config = config[mode]['dataset'] loader_config = config[mode]['loader'] - batch_size = loader_config['batch_size_per_card'] + if 'data_num_per_epoch' in loader_config.keys(): + data_num_per_epoch = loader_config['data_num_per_epoch'] + else: + data_num_per_epoch = None self.delimiter = dataset_config.get('delimiter', '\t') label_file_list = dataset_config.pop('label_file_list') @@ -43,21 +46,34 @@ class SimpleDataSet(Dataset): self.do_shuffle = loader_config['shuffle'] logger.info("Initialize indexs of datasets:%s" % label_file_list) - self.data_lines = self.get_image_info_list(label_file_list, ratio_list) + self.data_lines = self.get_image_info_list(label_file_list, ratio_list, + data_num_per_epoch) self.data_idx_order_list = list(range(len(self.data_lines))) if mode.lower() == "train": self.shuffle_data_random() self.ops = create_operators(dataset_config['transforms'], global_config) - def get_image_info_list(self, file_list, ratio_list): + def _sample_dataset(self, datas, sample_ratio, data_num_per_epoch=None): + sample_num = round(len(datas) * sample_ratio) + + if data_num_per_epoch is not None: + sample_num = data_num_per_epoch * sample_ratio + + nums, rem = sample_num // len(datas), sample_num % len(datas) + return list(datas) * nums + random.sample(datas, rem) + + def get_image_info_list(self, + file_list, + ratio_list, + data_num_per_epoch=None): if isinstance(file_list, str): file_list = [file_list] data_lines = [] for idx, file in enumerate(file_list): with open(file, "rb") as f: lines = f.readlines() - lines = random.sample(lines, - round(len(lines) * ratio_list[idx])) + lines = self._sample_dataset(lines, ratio_list[idx], + data_num_per_epoch) data_lines.extend(lines) return data_lines @@ -76,6 +92,8 @@ class SimpleDataSet(Dataset): label = substr[1] img_path = os.path.join(self.data_dir, file_name) data = {'img_path': img_path, 'label': label} + if not os.path.exists(img_path): + raise Exception("{} does not exist!".format(img_path)) with open(data['img_path'], 'rb') as f: img = f.read() data['image'] = img diff --git a/ppocr/modeling/backbones/det_mobilenet_v3.py b/ppocr/modeling/backbones/det_mobilenet_v3.py index 017dce2f..d6b453d1 100755 --- a/ppocr/modeling/backbones/det_mobilenet_v3.py +++ b/ppocr/modeling/backbones/det_mobilenet_v3.py @@ -34,13 +34,21 @@ def make_divisible(v, divisor=8, min_value=None): class MobileNetV3(nn.Layer): - def __init__(self, in_channels=3, model_name='large', scale=0.5, **kwargs): + def __init__(self, + in_channels=3, + model_name='large', + scale=0.5, + disable_se=False, + **kwargs): """ the MobilenetV3 backbone network for detection module. Args: params(dict): the super parameters for build network """ super(MobileNetV3, self).__init__() + + self.disable_se = disable_se + if model_name == "large": cfg = [ # k, exp, c, se, nl, s, @@ -223,7 +231,7 @@ class ResidualUnit(nn.Layer): if_act=True, act=act, name=name + "_depthwise") - if self.if_se: + if self.if_se and not self.disable_se: self.mid_se = SEModule(mid_channels, name=name + "_se") self.linear_conv = ConvBNLayer( in_channels=mid_channels, @@ -238,7 +246,7 @@ class ResidualUnit(nn.Layer): def forward(self, inputs): x = self.expand_conv(inputs) x = self.bottleneck_conv(x) - if self.if_se: + if self.if_se and not self.disable_se: x = self.mid_se(x) x = self.linear_conv(x) if self.if_shortcut: @@ -273,4 +281,4 @@ class SEModule(nn.Layer): outputs = F.relu(outputs) outputs = self.conv2(outputs) outputs = F.activation.hard_sigmoid(outputs) - return inputs * outputs \ No newline at end of file + return inputs * outputs diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index 316f7fc2..dc27abd6 100644 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -39,6 +39,7 @@ class DBPostProcess(object): self.max_candidates = max_candidates self.unclip_ratio = unclip_ratio self.min_size = 3 + self.dilation_kernel = np.array([[1, 1], [1, 1]]) def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): ''' @@ -139,8 +140,11 @@ class DBPostProcess(object): boxes_batch = [] for batch_index in range(pred.shape[0]): height, width = shape_list[batch_index] - boxes, scores = self.boxes_from_bitmap( - pred[batch_index], segmentation[batch_index], width, height) + mask = cv2.dilate( + np.array(segmentation[batch_index]).astype(np.uint8), + self.dilation_kernel) + boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, + width, height) boxes_batch.append({'points': boxes}) - return boxes_batch \ No newline at end of file + return boxes_batch diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py index 004322c8..af2de054 100644 --- a/ppocr/utils/save_load.py +++ b/ppocr/utils/save_load.py @@ -55,8 +55,8 @@ def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False): weight_name = weight_name.replace('binarize', '').replace( 'thresh', '') # for DB if weight_name in pre_state_dict.keys(): - logger.info('Load weight: {}, shape: {}'.format( - weight_name, pre_state_dict[weight_name].shape)) + # logger.info('Load weight: {}, shape: {}'.format( + # weight_name, pre_state_dict[weight_name].shape)) if 'encoder_rnn' in key: # delete axis which is 1 pre_state_dict[weight_name] = pre_state_dict[ From c0b4cefdcbc8848cd573892df22d6434d35ba04b Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Wed, 9 Dec 2020 20:26:40 +0800 Subject: [PATCH 2/7] fix comments and transform to transforms --- configs/det/{ => ch_ppocr_v1.1}/ch_det_mv3_db.yml | 0 configs/det/{ => ch_ppocr_v1.1}/ch_det_res18_db.yml | 2 +- ppocr/modeling/architectures/base_model.py | 2 +- ppocr/modeling/backbones/det_mobilenet_v3.py | 5 +++-- ppocr/modeling/{transform => transforms}/__init__.py | 0 ppocr/modeling/{transform => transforms}/tps.py | 0 ppocr/postprocess/db_postprocess.py | 12 ++++++++---- 7 files changed, 13 insertions(+), 8 deletions(-) rename configs/det/{ => ch_ppocr_v1.1}/ch_det_mv3_db.yml (100%) rename configs/det/{ => ch_ppocr_v1.1}/ch_det_res18_db.yml (97%) rename ppocr/modeling/{transform => transforms}/__init__.py (100%) rename ppocr/modeling/{transform => transforms}/tps.py (100%) diff --git a/configs/det/ch_det_mv3_db.yml b/configs/det/ch_ppocr_v1.1/ch_det_mv3_db.yml similarity index 100% rename from configs/det/ch_det_mv3_db.yml rename to configs/det/ch_ppocr_v1.1/ch_det_mv3_db.yml diff --git a/configs/det/ch_det_res18_db.yml b/configs/det/ch_ppocr_v1.1/ch_det_res18_db.yml similarity index 97% rename from configs/det/ch_det_res18_db.yml rename to configs/det/ch_ppocr_v1.1/ch_det_res18_db.yml index 9c903fa4..e34d9449 100644 --- a/configs/det/ch_det_res18_db.yml +++ b/configs/det/ch_ppocr_v1.1/ch_det_res18_db.yml @@ -10,7 +10,7 @@ Global: # if pretrained_model is saved in static mode, load_static_weights must set to True load_static_weights: True cal_metric_during_train: False - pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained + pretrained_model: ./pretrain_models/ResNet18_vd_pretrained checkpoints: #./output/det_db_0.001_DiceLoss_256_pp_config_2.0b_4gpu/best_accuracy save_inference_dir: use_visualdl: False diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index 0c4fe650..ab44b53a 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -16,7 +16,7 @@ from __future__ import division from __future__ import print_function from paddle import nn -from ppocr.modeling.transform import build_transform +from ppocr.modeling.transforms import build_transform from ppocr.modeling.backbones import build_backbone from ppocr.modeling.necks import build_neck from ppocr.modeling.heads import build_head diff --git a/ppocr/modeling/backbones/det_mobilenet_v3.py b/ppocr/modeling/backbones/det_mobilenet_v3.py index d6b453d1..f97bcfca 100755 --- a/ppocr/modeling/backbones/det_mobilenet_v3.py +++ b/ppocr/modeling/backbones/det_mobilenet_v3.py @@ -111,6 +111,7 @@ class MobileNetV3(nn.Layer): i = 0 inplanes = make_divisible(inplanes * scale) for (k, exp, c, se, nl, s) in cfg: + se = se and not self.disable_se if s == 2 and i > 2: self.out_channels.append(inplanes) self.stages.append(nn.Sequential(*block_list)) @@ -231,7 +232,7 @@ class ResidualUnit(nn.Layer): if_act=True, act=act, name=name + "_depthwise") - if self.if_se and not self.disable_se: + if self.if_se: self.mid_se = SEModule(mid_channels, name=name + "_se") self.linear_conv = ConvBNLayer( in_channels=mid_channels, @@ -246,7 +247,7 @@ class ResidualUnit(nn.Layer): def forward(self, inputs): x = self.expand_conv(inputs) x = self.bottleneck_conv(x) - if self.if_se and not self.disable_se: + if self.if_se: x = self.mid_se(x) x = self.linear_conv(x) if self.if_shortcut: diff --git a/ppocr/modeling/transform/__init__.py b/ppocr/modeling/transforms/__init__.py similarity index 100% rename from ppocr/modeling/transform/__init__.py rename to ppocr/modeling/transforms/__init__.py diff --git a/ppocr/modeling/transform/tps.py b/ppocr/modeling/transforms/tps.py similarity index 100% rename from ppocr/modeling/transform/tps.py rename to ppocr/modeling/transforms/tps.py diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index dc27abd6..b0a67b01 100644 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -33,13 +33,14 @@ class DBPostProcess(object): box_thresh=0.7, max_candidates=1000, unclip_ratio=2.0, + use_dilation=False, **kwargs): self.thresh = thresh self.box_thresh = box_thresh self.max_candidates = max_candidates self.unclip_ratio = unclip_ratio self.min_size = 3 - self.dilation_kernel = np.array([[1, 1], [1, 1]]) + self.dilation_kernel = None if not use_dilation else [[1, 1], [1, 1]] def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): ''' @@ -140,9 +141,12 @@ class DBPostProcess(object): boxes_batch = [] for batch_index in range(pred.shape[0]): height, width = shape_list[batch_index] - mask = cv2.dilate( - np.array(segmentation[batch_index]).astype(np.uint8), - self.dilation_kernel) + if self.dilation_kernel is not None: + mask = cv2.dilate( + np.array(segmentation[batch_index]).astype(np.uint8), + self.dilation_kernel) + else: + mask = segmentation[batch_index] boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, width, height) From 7cce85cc5c170aeb39f4fcfe0a32b0508323787e Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Wed, 9 Dec 2020 20:44:43 +0800 Subject: [PATCH 3/7] fix conflicts --- ppocr/postprocess/db_postprocess.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ppocr/postprocess/db_postprocess.py b/ppocr/postprocess/db_postprocess.py index b0a67b01..0be2c12a 100644 --- a/ppocr/postprocess/db_postprocess.py +++ b/ppocr/postprocess/db_postprocess.py @@ -140,7 +140,7 @@ class DBPostProcess(object): boxes_batch = [] for batch_index in range(pred.shape[0]): - height, width = shape_list[batch_index] + src_h, src_w, ratio_h, ratio_w = shape_list[batch_index] if self.dilation_kernel is not None: mask = cv2.dilate( np.array(segmentation[batch_index]).astype(np.uint8), @@ -148,7 +148,7 @@ class DBPostProcess(object): else: mask = segmentation[batch_index] boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, - width, height) + src_w, src_h) boxes_batch.append({'points': boxes}) return boxes_batch From a5b219127fc4d316229775d956eba6a62fcebecd Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Wed, 9 Dec 2020 22:07:20 +0800 Subject: [PATCH 4/7] fix conflicts --- .../{ch_det_mv3_db.yml => ch_det_mv3_db_v2.0.yml} | 0 .../{ch_det_res18_db.yml => ch_det_res18_db_v2.0.yml} | 0 ppocr/data/simple_dataset.py | 4 ++-- 3 files changed, 2 insertions(+), 2 deletions(-) rename configs/det/ch_ppocr_v1.1/{ch_det_mv3_db.yml => ch_det_mv3_db_v2.0.yml} (100%) rename configs/det/ch_ppocr_v1.1/{ch_det_res18_db.yml => ch_det_res18_db_v2.0.yml} (100%) diff --git a/configs/det/ch_ppocr_v1.1/ch_det_mv3_db.yml b/configs/det/ch_ppocr_v1.1/ch_det_mv3_db_v2.0.yml similarity index 100% rename from configs/det/ch_ppocr_v1.1/ch_det_mv3_db.yml rename to configs/det/ch_ppocr_v1.1/ch_det_mv3_db_v2.0.yml diff --git a/configs/det/ch_ppocr_v1.1/ch_det_res18_db.yml b/configs/det/ch_ppocr_v1.1/ch_det_res18_db_v2.0.yml similarity index 100% rename from configs/det/ch_ppocr_v1.1/ch_det_res18_db.yml rename to configs/det/ch_ppocr_v1.1/ch_det_res18_db_v2.0.yml diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index 16891326..d2069d0d 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -57,9 +57,9 @@ class SimpleDataSet(Dataset): sample_num = round(len(datas) * sample_ratio) if data_num_per_epoch is not None: - sample_num = data_num_per_epoch * sample_ratio + sample_num = int(data_num_per_epoch * sample_ratio) - nums, rem = sample_num // len(datas), sample_num % len(datas) + nums, rem = int(sample_num // len(datas)), int(sample_num % len(datas)) return list(datas) * nums + random.sample(datas, rem) def get_image_info_list(self, From e23c4de5d8555da192d70e8ca1144c5a5b3a1e75 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Thu, 10 Dec 2020 10:12:50 +0800 Subject: [PATCH 5/7] 1.1 to 2.0 --- .../det/{ch_ppocr_v1.1 => ch_ppocr_v2.0}/ch_det_mv3_db_v2.0.yml | 0 .../det/{ch_ppocr_v1.1 => ch_ppocr_v2.0}/ch_det_res18_db_v2.0.yml | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename configs/det/{ch_ppocr_v1.1 => ch_ppocr_v2.0}/ch_det_mv3_db_v2.0.yml (100%) rename configs/det/{ch_ppocr_v1.1 => ch_ppocr_v2.0}/ch_det_res18_db_v2.0.yml (100%) diff --git a/configs/det/ch_ppocr_v1.1/ch_det_mv3_db_v2.0.yml b/configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml similarity index 100% rename from configs/det/ch_ppocr_v1.1/ch_det_mv3_db_v2.0.yml rename to configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml diff --git a/configs/det/ch_ppocr_v1.1/ch_det_res18_db_v2.0.yml b/configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml similarity index 100% rename from configs/det/ch_ppocr_v1.1/ch_det_res18_db_v2.0.yml rename to configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml From b8ba7035487d0436e69a2959e2daf71c4187bf9b Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Thu, 10 Dec 2020 10:19:39 +0800 Subject: [PATCH 6/7] delete data_num_per_epoch --- ppocr/data/simple_dataset.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index d2069d0d..1099fa44 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -27,17 +27,13 @@ class SimpleDataSet(Dataset): global_config = config['Global'] dataset_config = config[mode]['dataset'] loader_config = config[mode]['loader'] - if 'data_num_per_epoch' in loader_config.keys(): - data_num_per_epoch = loader_config['data_num_per_epoch'] - else: - data_num_per_epoch = None self.delimiter = dataset_config.get('delimiter', '\t') label_file_list = dataset_config.pop('label_file_list') data_source_num = len(label_file_list) ratio_list = dataset_config.get("ratio_list", [1.0]) if isinstance(ratio_list, (float, int)): - ratio_list = [float(ratio_list)] * len(data_source_num) + ratio_list = [float(ratio_list)] * int(data_source_num) assert len( ratio_list @@ -46,34 +42,26 @@ class SimpleDataSet(Dataset): self.do_shuffle = loader_config['shuffle'] logger.info("Initialize indexs of datasets:%s" % label_file_list) - self.data_lines = self.get_image_info_list(label_file_list, ratio_list, - data_num_per_epoch) + self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_idx_order_list = list(range(len(self.data_lines))) if mode.lower() == "train": self.shuffle_data_random() self.ops = create_operators(dataset_config['transforms'], global_config) - def _sample_dataset(self, datas, sample_ratio, data_num_per_epoch=None): + def _sample_dataset(self, datas, sample_ratio): sample_num = round(len(datas) * sample_ratio) - if data_num_per_epoch is not None: - sample_num = int(data_num_per_epoch * sample_ratio) - nums, rem = int(sample_num // len(datas)), int(sample_num % len(datas)) return list(datas) * nums + random.sample(datas, rem) - def get_image_info_list(self, - file_list, - ratio_list, - data_num_per_epoch=None): + def get_image_info_list(self, file_list, ratio_list): if isinstance(file_list, str): file_list = [file_list] data_lines = [] for idx, file in enumerate(file_list): with open(file, "rb") as f: lines = f.readlines() - lines = self._sample_dataset(lines, ratio_list[idx], - data_num_per_epoch) + lines = self._sample_dataset(lines, ratio_list[idx]) data_lines.extend(lines) return data_lines From d97d98fe01dc3dc6e3e42c787269a7a89f96c4c2 Mon Sep 17 00:00:00 2001 From: LDOUBLEV Date: Thu, 10 Dec 2020 11:00:05 +0800 Subject: [PATCH 7/7] opt random sample --- ppocr/data/simple_dataset.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index 1099fa44..ab17dd1a 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -48,12 +48,6 @@ class SimpleDataSet(Dataset): self.shuffle_data_random() self.ops = create_operators(dataset_config['transforms'], global_config) - def _sample_dataset(self, datas, sample_ratio): - sample_num = round(len(datas) * sample_ratio) - - nums, rem = int(sample_num // len(datas)), int(sample_num % len(datas)) - return list(datas) * nums + random.sample(datas, rem) - def get_image_info_list(self, file_list, ratio_list): if isinstance(file_list, str): file_list = [file_list] @@ -61,7 +55,8 @@ class SimpleDataSet(Dataset): for idx, file in enumerate(file_list): with open(file, "rb") as f: lines = f.readlines() - lines = self._sample_dataset(lines, ratio_list[idx]) + lines = random.sample(lines, + round(len(lines) * ratio_list[idx])) data_lines.extend(lines) return data_lines