commit
de8b5b2593
|
@ -4,7 +4,7 @@ Global:
|
|||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: output
|
||||
save_model_dir: ./output/det_db/
|
||||
save_epoch_step: 200
|
||||
eval_batch_step: 5000
|
||||
train_batch_size_per_card: 16
|
||||
|
@ -13,7 +13,7 @@ Global:
|
|||
reader_yml: ./configs/det/det_db_icdar15_reader.yml
|
||||
pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/
|
||||
checkpoints:
|
||||
save_res_path: ./output/predicts_db.txt
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
save_inference_dir:
|
||||
|
||||
Architecture:
|
||||
|
|
|
@ -4,7 +4,7 @@ Global:
|
|||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: output
|
||||
save_model_dir: ./output/det_db/
|
||||
save_epoch_step: 200
|
||||
eval_batch_step: 5000
|
||||
train_batch_size_per_card: 8
|
||||
|
@ -12,7 +12,9 @@ Global:
|
|||
image_shape: [3, 640, 640]
|
||||
reader_yml: ./configs/det/det_db_icdar15_reader.yml
|
||||
pretrain_weights: ./pretrain_models/ResNet50_vd_pretrained/
|
||||
save_res_path: ./output/predicts_db.txt
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
|
||||
Architecture:
|
||||
function: ppocr.modeling.architectures.det_model,DetModel
|
||||
|
|
|
@ -4,7 +4,7 @@ Global:
|
|||
epoch_num: 100000
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 5
|
||||
save_model_dir: output
|
||||
save_model_dir: ./output/det_east/
|
||||
save_epoch_step: 200
|
||||
eval_batch_step: 5000
|
||||
train_batch_size_per_card: 16
|
||||
|
@ -12,7 +12,9 @@ Global:
|
|||
image_shape: [3, 512, 512]
|
||||
reader_yml: ./configs/det/det_east_icdar15_reader.yml
|
||||
pretrain_weights: ./pretrain_models/MobileNetV3_pretrained/MobileNetV3_large_x0_5_pretrained/
|
||||
save_res_path: ./output/predicts_east.txt
|
||||
checkpoints:
|
||||
save_res_path: ./output/det_east/predicts_east.txt
|
||||
save_inference_dir:
|
||||
|
||||
Architecture:
|
||||
function: ppocr.modeling.architectures.det_model,DetModel
|
||||
|
|
|
@ -4,7 +4,7 @@ Global:
|
|||
epoch_num: 100000
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 5
|
||||
save_model_dir: output
|
||||
save_model_dir: ./output/det_east/
|
||||
save_epoch_step: 200
|
||||
eval_batch_step: 5000
|
||||
train_batch_size_per_card: 8
|
||||
|
@ -12,7 +12,9 @@ Global:
|
|||
image_shape: [3, 512, 512]
|
||||
reader_yml: ./configs/det/det_east_icdar15_reader.yml
|
||||
pretrain_weights: ./pretrain_models/ResNet50_vd_pretrained/
|
||||
save_res_path: ./output/predicts_east.txt
|
||||
save_res_path: ./output/det_east/predicts_east.txt
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
|
||||
Architecture:
|
||||
function: ppocr.modeling.architectures.det_model,DetModel
|
||||
|
|
|
@ -4,7 +4,7 @@ Global:
|
|||
epoch_num: 72
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: output
|
||||
save_model_dir: output/rec_CRNN
|
||||
save_epoch_step: 3
|
||||
eval_batch_step: 2000
|
||||
train_batch_size_per_card: 256
|
||||
|
@ -15,6 +15,8 @@ Global:
|
|||
loss_type: ctc
|
||||
reader_yml: ./configs/rec/rec_benchmark_reader.yml
|
||||
pretrain_weights:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
|
||||
Architecture:
|
||||
function: ppocr.modeling.architectures.rec_model,RecModel
|
||||
|
|
|
@ -4,7 +4,7 @@ Global:
|
|||
epoch_num: 72
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: output
|
||||
save_model_dir: output/rec_Rosetta
|
||||
save_epoch_step: 3
|
||||
eval_batch_step: 2000
|
||||
train_batch_size_per_card: 256
|
||||
|
@ -15,6 +15,8 @@ Global:
|
|||
loss_type: ctc
|
||||
reader_yml: ./configs/rec/rec_benchmark_reader.yml
|
||||
pretrain_weights:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
|
||||
Architecture:
|
||||
function: ppocr.modeling.architectures.rec_model,RecModel
|
||||
|
|
|
@ -4,7 +4,7 @@ Global:
|
|||
epoch_num: 72
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: output
|
||||
save_model_dir: output/rec_RARE
|
||||
save_epoch_step: 3
|
||||
eval_batch_step: 2000
|
||||
train_batch_size_per_card: 256
|
||||
|
@ -15,6 +15,8 @@ Global:
|
|||
loss_type: attention
|
||||
reader_yml: ./configs/rec/rec_benchmark_reader.yml
|
||||
pretrain_weights:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
|
||||
Architecture:
|
||||
function: ppocr.modeling.architectures.rec_model,RecModel
|
||||
|
|
|
@ -4,7 +4,7 @@ Global:
|
|||
epoch_num: 72
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: output
|
||||
save_model_dir: output/rec_STARNet
|
||||
save_epoch_step: 3
|
||||
eval_batch_step: 2000
|
||||
train_batch_size_per_card: 256
|
||||
|
@ -15,6 +15,9 @@ Global:
|
|||
loss_type: ctc
|
||||
reader_yml: ./configs/rec/rec_benchmark_reader.yml
|
||||
pretrain_weights:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
|
||||
|
||||
Architecture:
|
||||
function: ppocr.modeling.architectures.rec_model,RecModel
|
||||
|
|
|
@ -4,7 +4,7 @@ Global:
|
|||
epoch_num: 72
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: output
|
||||
save_model_dir: output/rec_CRNN
|
||||
save_epoch_step: 3
|
||||
eval_batch_step: 2000
|
||||
train_batch_size_per_card: 256
|
||||
|
@ -15,6 +15,8 @@ Global:
|
|||
loss_type: ctc
|
||||
reader_yml: ./configs/rec/rec_benchmark_reader.yml
|
||||
pretrain_weights:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
|
||||
Architecture:
|
||||
function: ppocr.modeling.architectures.rec_model,RecModel
|
||||
|
|
|
@ -4,7 +4,7 @@ Global:
|
|||
epoch_num: 72
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: output
|
||||
save_model_dir: output/rec_Rosetta
|
||||
save_epoch_step: 3
|
||||
eval_batch_step: 2000
|
||||
train_batch_size_per_card: 256
|
||||
|
@ -15,6 +15,8 @@ Global:
|
|||
loss_type: ctc
|
||||
reader_yml: ./configs/rec/rec_benchmark_reader.yml
|
||||
pretrain_weights:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
|
||||
Architecture:
|
||||
function: ppocr.modeling.architectures.rec_model,RecModel
|
||||
|
|
|
@ -4,7 +4,7 @@ Global:
|
|||
epoch_num: 72
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: output
|
||||
save_model_dir: output/rec_RARE
|
||||
save_epoch_step: 3
|
||||
eval_batch_step: 2000
|
||||
train_batch_size_per_card: 256
|
||||
|
@ -15,6 +15,8 @@ Global:
|
|||
loss_type: attention
|
||||
reader_yml: ./configs/rec/rec_benchmark_reader.yml
|
||||
pretrain_weights:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
|
||||
Architecture:
|
||||
function: ppocr.modeling.architectures.rec_model,RecModel
|
||||
|
|
|
@ -4,7 +4,7 @@ Global:
|
|||
epoch_num: 72
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: output
|
||||
save_model_dir: output/rec_STARNet
|
||||
save_epoch_step: 3
|
||||
eval_batch_step: 2000
|
||||
train_batch_size_per_card: 256
|
||||
|
@ -15,6 +15,8 @@ Global:
|
|||
loss_type: ctc
|
||||
reader_yml: ./configs/rec/rec_benchmark_reader.yml
|
||||
pretrain_weights:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
|
||||
Architecture:
|
||||
function: ppocr.modeling.architectures.rec_model,RecModel
|
||||
|
|
|
@ -196,7 +196,7 @@ class DBHead(object):
|
|||
fuse = fluid.layers.concat(input=[p5, p4, p3, p2], axis=1)
|
||||
shrink_maps = self.binarize(fuse)
|
||||
if mode != "train":
|
||||
return shrink_maps
|
||||
return {"maps", shrink_maps}
|
||||
threshold_maps = self.thresh(fuse)
|
||||
binary_maps = self.step_function(shrink_maps, threshold_maps)
|
||||
y = fluid.layers.concat(
|
||||
|
|
|
@ -128,6 +128,7 @@ class DBPostProcess(object):
|
|||
|
||||
def __call__(self, outs_dict, ratio_list):
|
||||
pred = outs_dict['maps']
|
||||
|
||||
pred = pred[:, 0, :, :]
|
||||
segmentation = pred > self.thresh
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ import copy
|
|||
import numpy as np
|
||||
import math
|
||||
import time
|
||||
import sys
|
||||
|
||||
|
||||
class TextDetector(object):
|
||||
|
@ -52,10 +53,10 @@ class TextDetector(object):
|
|||
utility.create_predictor(args, mode="det")
|
||||
|
||||
def order_points_clockwise(self, pts):
|
||||
#######
|
||||
## https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
|
||||
########
|
||||
"""
|
||||
reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
|
||||
# sort the points based on their x-coordinates
|
||||
"""
|
||||
xSorted = pts[np.argsort(pts[:, 0]), :]
|
||||
|
||||
# grab the left-most and right-most points from the sorted
|
||||
|
@ -141,7 +142,7 @@ class TextDetector(object):
|
|||
outs_dict['f_score'] = outputs[0]
|
||||
outs_dict['f_geo'] = outputs[1]
|
||||
else:
|
||||
outs_dict['maps'] = [outputs[0]]
|
||||
outs_dict['maps'] = outputs[0]
|
||||
dt_boxes_list = self.postprocess_op(outs_dict, [ratio_list])
|
||||
dt_boxes = dt_boxes_list[0]
|
||||
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
|
||||
|
|
|
@ -219,6 +219,8 @@ def train_eval_det_run(config, exe, train_info_dict, eval_info_dict):
|
|||
eval_batch_step = config['Global']['eval_batch_step']
|
||||
save_epoch_step = config['Global']['save_epoch_step']
|
||||
save_model_dir = config['Global']['save_model_dir']
|
||||
if not os.path.exists(save_model_dir):
|
||||
os.makedirs(save_model_dir)
|
||||
train_stats = TrainingStats(log_smooth_window,
|
||||
train_info_dict['fetch_name_list'])
|
||||
best_eval_hmean = -1
|
||||
|
@ -282,6 +284,8 @@ def train_eval_rec_run(config, exe, train_info_dict, eval_info_dict):
|
|||
eval_batch_step = config['Global']['eval_batch_step']
|
||||
save_epoch_step = config['Global']['save_epoch_step']
|
||||
save_model_dir = config['Global']['save_model_dir']
|
||||
if not os.path.exists(save_model_dir):
|
||||
os.makedirs(save_model_dir)
|
||||
train_stats = TrainingStats(log_smooth_window, ['loss', 'acc'])
|
||||
best_eval_acc = -1
|
||||
best_batch_id = 0
|
||||
|
|
Loading…
Reference in New Issue