commit
836839bbf6
|
@ -0,0 +1,134 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/ch_db_mv3/
|
||||
save_epoch_step: 1200
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [3000, 2000]
|
||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||
load_static_weights: True
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
checkpoints: #./output/det_db_0.001_DiceLoss_256_pp_config_2.0b_4gpu/best_accuracy
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
Loss:
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DBPostProcess
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DetMetric
|
||||
main_indicator: hmean
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||
- EastRandomCropData:
|
||||
size: [960, 960]
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
# image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -0,0 +1,133 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/ch_db_res18/
|
||||
save_epoch_step: 1200
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [3000, 2000]
|
||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||
load_static_weights: True
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/ResNet18_vd_pretrained
|
||||
checkpoints: #./output/det_db_0.001_DiceLoss_256_pp_config_2.0b_4gpu/best_accuracy
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 18
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 256
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
Loss:
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DBPostProcess
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DetMetric
|
||||
main_indicator: hmean
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||
- EastRandomCropData:
|
||||
size: [960, 960]
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
# image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -42,6 +42,8 @@ class DecodeImage(object):
|
|||
img) > 0, "invalid input 'img' in DecodeImage"
|
||||
img = np.frombuffer(img, dtype='uint8')
|
||||
img = cv2.imdecode(img, 1)
|
||||
if img is None:
|
||||
return None
|
||||
if self.img_mode == 'GRAY':
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
elif self.img_mode == 'RGB':
|
||||
|
@ -197,7 +199,7 @@ class DetResizeForTest(object):
|
|||
ratio_w = resize_w / float(w)
|
||||
# return img, np.array([h, w])
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
|
||||
def resize_image_type2(self, img):
|
||||
h, w, _ = img.shape
|
||||
|
||||
|
|
|
@ -27,14 +27,13 @@ class SimpleDataSet(Dataset):
|
|||
global_config = config['Global']
|
||||
dataset_config = config[mode]['dataset']
|
||||
loader_config = config[mode]['loader']
|
||||
batch_size = loader_config['batch_size_per_card']
|
||||
|
||||
self.delimiter = dataset_config.get('delimiter', '\t')
|
||||
label_file_list = dataset_config.pop('label_file_list')
|
||||
data_source_num = len(label_file_list)
|
||||
ratio_list = dataset_config.get("ratio_list", [1.0])
|
||||
if isinstance(ratio_list, (float, int)):
|
||||
ratio_list = [float(ratio_list)] * len(data_source_num)
|
||||
ratio_list = [float(ratio_list)] * int(data_source_num)
|
||||
|
||||
assert len(
|
||||
ratio_list
|
||||
|
@ -76,6 +75,8 @@ class SimpleDataSet(Dataset):
|
|||
label = substr[1]
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
data = {'img_path': img_path, 'label': label}
|
||||
if not os.path.exists(img_path):
|
||||
raise Exception("{} does not exist!".format(img_path))
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
|
|
|
@ -16,7 +16,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
from paddle import nn
|
||||
from ppocr.modeling.transform import build_transform
|
||||
from ppocr.modeling.transforms import build_transform
|
||||
from ppocr.modeling.backbones import build_backbone
|
||||
from ppocr.modeling.necks import build_neck
|
||||
from ppocr.modeling.heads import build_head
|
||||
|
|
|
@ -34,13 +34,21 @@ def make_divisible(v, divisor=8, min_value=None):
|
|||
|
||||
|
||||
class MobileNetV3(nn.Layer):
|
||||
def __init__(self, in_channels=3, model_name='large', scale=0.5, **kwargs):
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
model_name='large',
|
||||
scale=0.5,
|
||||
disable_se=False,
|
||||
**kwargs):
|
||||
"""
|
||||
the MobilenetV3 backbone network for detection module.
|
||||
Args:
|
||||
params(dict): the super parameters for build network
|
||||
"""
|
||||
super(MobileNetV3, self).__init__()
|
||||
|
||||
self.disable_se = disable_se
|
||||
|
||||
if model_name == "large":
|
||||
cfg = [
|
||||
# k, exp, c, se, nl, s,
|
||||
|
@ -103,6 +111,7 @@ class MobileNetV3(nn.Layer):
|
|||
i = 0
|
||||
inplanes = make_divisible(inplanes * scale)
|
||||
for (k, exp, c, se, nl, s) in cfg:
|
||||
se = se and not self.disable_se
|
||||
if s == 2 and i > 2:
|
||||
self.out_channels.append(inplanes)
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
|
@ -273,4 +282,4 @@ class SEModule(nn.Layer):
|
|||
outputs = F.relu(outputs)
|
||||
outputs = self.conv2(outputs)
|
||||
outputs = F.activation.hard_sigmoid(outputs)
|
||||
return inputs * outputs
|
||||
return inputs * outputs
|
||||
|
|
|
@ -33,12 +33,14 @@ class DBPostProcess(object):
|
|||
box_thresh=0.7,
|
||||
max_candidates=1000,
|
||||
unclip_ratio=2.0,
|
||||
use_dilation=False,
|
||||
**kwargs):
|
||||
self.thresh = thresh
|
||||
self.box_thresh = box_thresh
|
||||
self.max_candidates = max_candidates
|
||||
self.unclip_ratio = unclip_ratio
|
||||
self.min_size = 3
|
||||
self.dilation_kernel = None if not use_dilation else [[1, 1], [1, 1]]
|
||||
|
||||
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
||||
'''
|
||||
|
@ -139,8 +141,14 @@ class DBPostProcess(object):
|
|||
boxes_batch = []
|
||||
for batch_index in range(pred.shape[0]):
|
||||
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
|
||||
boxes, scores = self.boxes_from_bitmap(
|
||||
pred[batch_index], segmentation[batch_index], src_w, src_h)
|
||||
if self.dilation_kernel is not None:
|
||||
mask = cv2.dilate(
|
||||
np.array(segmentation[batch_index]).astype(np.uint8),
|
||||
self.dilation_kernel)
|
||||
else:
|
||||
mask = segmentation[batch_index]
|
||||
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
|
||||
src_w, src_h)
|
||||
|
||||
boxes_batch.append({'points': boxes})
|
||||
return boxes_batch
|
||||
return boxes_batch
|
||||
|
|
|
@ -55,8 +55,8 @@ def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
|
|||
weight_name = weight_name.replace('binarize', '').replace(
|
||||
'thresh', '') # for DB
|
||||
if weight_name in pre_state_dict.keys():
|
||||
logger.info('Load weight: {}, shape: {}'.format(
|
||||
weight_name, pre_state_dict[weight_name].shape))
|
||||
# logger.info('Load weight: {}, shape: {}'.format(
|
||||
# weight_name, pre_state_dict[weight_name].shape))
|
||||
if 'encoder_rnn' in key:
|
||||
# delete axis which is 1
|
||||
pre_state_dict[weight_name] = pre_state_dict[
|
||||
|
|
Loading…
Reference in New Issue