add sast code

parent a4d245185e
commit 224667b895
@@ -0,0 +1,50 @@
Global:
  algorithm: SAST
  use_gpu: true
  epoch_num: 2000
  log_smooth_window: 20
  print_batch_step: 2
  save_model_dir: ./output/det_sast/
  save_epoch_step: 20
  eval_batch_step: 5000
  train_batch_size_per_card: 8
  test_batch_size_per_card: 8
  image_shape: [3, 512, 512]
  reader_yml: ./configs/det/det_sast_icdar15_reader.yml
  pretrain_weights: ./pretrain_models/ResNet50_vd_ssld_pretrained/
  save_res_path: ./output/det_sast/predicts_sast.txt
  checkpoints:
  save_inference_dir:

Architecture:
  function: ppocr.modeling.architectures.det_model,DetModel

Backbone:
  function: ppocr.modeling.backbones.det_resnet_vd_sast,ResNet
  layers: 50

Head:
  function: ppocr.modeling.heads.det_sast_head,SASTHead
  model_name: large
  only_fpn_up: False
  # with_cab: False
  with_cab: True

Loss:
  function: ppocr.modeling.losses.det_sast_loss,SASTLoss

Optimizer:
  function: ppocr.optimizer,RMSProp
  base_lr: 0.001
  decay:
    function: piecewise_decay
    boundaries: [30000, 50000, 80000, 100000, 150000]
    decay_rate: 0.3

PostProcess:
  function: ppocr.postprocess.sast_postprocess,SASTPostProcess
  score_thresh: 0.5
  sample_pts_num: 2
  nms_thresh: 0.2
  expand_scale: 1.0
  shrink_ratio_of_width: 0.3
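A minimal sketch (not part of the commit) of how the piecewise_decay settings above are assumed to behave in ppocr.optimizer: the learning rate starts at base_lr and is multiplied by decay_rate at every boundary step. The helper name is hypothetical.

def piecewise_lr(step, base_lr=0.001,
                 boundaries=(30000, 50000, 80000, 100000, 150000),
                 decay_rate=0.3):
    # assumption: each crossed boundary multiplies the rate by decay_rate
    lr = base_lr
    for b in boundaries:
        if step >= b:
            lr *= decay_rate
    return lr

print(piecewise_lr(0))      # 0.001 before the first boundary
print(piecewise_lr(60000))  # 0.001 * 0.3 ** 2 = 9e-05 after two boundaries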
@@ -0,0 +1,50 @@
Global:
  algorithm: SAST
  use_gpu: true
  epoch_num: 2000
  log_smooth_window: 20
  print_batch_step: 2
  save_model_dir: ./output/det_sast/
  save_epoch_step: 20
  eval_batch_step: 5000
  train_batch_size_per_card: 8
  test_batch_size_per_card: 1
  image_shape: [3, 512, 512]
  reader_yml: ./configs/det/det_sast_totaltext_reader.yml
  pretrain_weights: ./pretrain_models/ResNet50_vd_ssld_pretrained/
  save_res_path: ./output/det_sast/predicts_sast.txt
  checkpoints:
  save_inference_dir:

Architecture:
  function: ppocr.modeling.architectures.det_model,DetModel

Backbone:
  function: ppocr.modeling.backbones.det_resnet_vd_sast,ResNet
  layers: 50

Head:
  function: ppocr.modeling.heads.det_sast_head,SASTHead
  model_name: large
  only_fpn_up: False
  # with_cab: False
  with_cab: True

Loss:
  function: ppocr.modeling.losses.det_sast_loss,SASTLoss

Optimizer:
  function: ppocr.optimizer,RMSProp
  base_lr: 0.001
  decay:
    function: piecewise_decay
    boundaries: [30000, 50000, 80000, 100000, 150000]
    decay_rate: 0.3

PostProcess:
  function: ppocr.postprocess.sast_postprocess,SASTPostProcess
  score_thresh: 0.5
  sample_pts_num: 6
  nms_thresh: 0.2
  expand_scale: 1.2
  shrink_ratio_of_width: 0.2
@@ -0,0 +1,26 @@
TrainReader:
  reader_function: ppocr.data.det.dataset_traversal,TrainReader
  process_function: ppocr.data.det.sast_process,SASTProcessTrain
  num_workers: 8
  img_set_dir: ./train_data/
  label_file_path: [./train_data/icdar13/train_label_json.txt, ./train_data/icdar15/train_label_json.txt, ./train_data/icdar17_mlt_latin/train_label_json.txt, ./train_data/coco_text_icdar_4pts/train_label_json.txt]
  data_ratio_list: [0.1, 0.45, 0.3, 0.15]
  min_crop_side_ratio: 0.3
  min_crop_size: 24
  min_text_size: 4
  max_text_size: 512

EvalReader:
  reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
  process_function: ppocr.data.det.sast_process,SASTProcessTest
  img_set_dir: ./train_data/icdar2015/text_localization/
  label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
  max_side_len: 1536

TestReader:
  reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
  process_function: ppocr.data.det.sast_process,SASTProcessTest
  infer_img:
  img_set_dir: ./train_data/icdar2015/text_localization/
  label_file_path: ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
  do_eval: True
@@ -0,0 +1,24 @@
TrainReader:
  reader_function: ppocr.data.det.dataset_traversal,TrainReader
  process_function: ppocr.data.det.sast_process,SASTProcessTrain
  num_workers: 8
  img_set_dir: ./train_data/
  label_file_path: [./train_data/art_latin_icdar_14pt/train_no_tt_test/train_label_json.txt, ./train_data/total_text_icdar_14pt/train/train_label_json.txt]
  data_ratio_list: [0.5, 0.5]
  min_crop_side_ratio: 0.3
  min_crop_size: 24
  min_text_size: 4
  max_text_size: 512

EvalReader:
  reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
  process_function: ppocr.data.det.sast_process,SASTProcessTest
  img_set_dir: ./train_data/afs/
  label_file_path: ./train_data/afs/total_text/test_label_json.txt
  max_side_len: 768

TestReader:
  reader_function: ppocr.data.det.dataset_traversal,EvalTestReader
  process_function: ppocr.data.det.sast_process,SASTProcessTest
  infer_img:
  max_side_len: 768
@@ -31,6 +31,11 @@ class TrainReader(object):
     def __init__(self, params):
         self.num_workers = params['num_workers']
         self.label_file_path = params['label_file_path']
+        print(self.label_file_path)
+        self.use_mul_data = False
+        if isinstance(self.label_file_path, list):
+            self.use_mul_data = True
+            self.data_ratio_list = params['data_ratio_list']
         self.batch_size = params['train_batch_size_per_card']
         assert 'process_function' in params,\
             "absence process_function in Reader"
@@ -43,7 +48,7 @@ class TrainReader(object):
         img_num = len(label_infor_list)
         img_id_list = list(range(img_num))
         random.shuffle(img_id_list)
-        if sys.platform == "win32":
+        if sys.platform == "win32" and self.num_workers != 1:
             print("multiprocess is not fully compatible with Windows."
                   "num_workers will be 1.")
             self.num_workers = 1
@@ -54,13 +59,64 @@ class TrainReader(object):
                     continue
                 yield outs

+        def sample_iter_reader_mul():
+            batch_size = 1000
+            data_source_list = self.label_file_path
+            batch_size_list = list(map(int, [max(1.0, batch_size * x) for x in self.data_ratio_list]))
+            print(self.data_ratio_list, batch_size_list)
+
+            data_filename_list, data_size_list, fetch_record_list = [], [], []
+            for data_source in data_source_list:
+                image_files = open(data_source, "rb").readlines()
+                random.shuffle(image_files)
+                data_filename_list.append(image_files)
+                data_size_list.append(len(image_files))
+                fetch_record_list.append(0)
+
+            image_batch, poly_batch = [], []
+            # get a batch of img_fns and poly_fns
+            for i in range(0, len(batch_size_list)):
+                bs = batch_size_list[i]
+                ds = data_size_list[i]
+                image_names = data_filename_list[i]
+                fetch_record = fetch_record_list[i]
+                data_path = data_source_list[i]
+                for j in range(fetch_record, fetch_record + bs):
+                    index = j % ds
+                    image_batch.append(image_names[index])
+
+                if (fetch_record + bs) > ds:
+                    fetch_record_list[i] = 0
+                    random.shuffle(data_filename_list[i])
+                else:
+                    fetch_record_list[i] = fetch_record + bs
+
+            if sys.platform == "win32":
+                print("multiprocess is not fully compatible with Windows."
+                      "num_workers will be 1.")
+                self.num_workers = 1
+
+            for label_infor in image_batch:
+                outs = self.process(label_infor)
+                if outs is None:
+                    continue
+                yield outs
+
         def batch_iter_reader():
             batch_outs = []
-            for outs in sample_iter_reader():
-                batch_outs.append(outs)
-                if len(batch_outs) == self.batch_size:
-                    yield batch_outs
-                    batch_outs = []
+            if self.use_mul_data:
+                print("Sample data from multiple datasets!")
+                for outs in sample_iter_reader_mul():
+                    batch_outs.append(outs)
+                    if len(batch_outs) == self.batch_size:
+                        yield batch_outs
+                        batch_outs = []
+            else:
+                for outs in sample_iter_reader():
+                    batch_outs.append(outs)
+                    if len(batch_outs) == self.batch_size:
+                        yield batch_outs
+                        batch_outs = []

         return batch_iter_reader
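A standalone sketch (not from the commit) of the ratio-based mixing that sample_iter_reader_mul implements: each chunk of roughly 1000 samples is drawn from the datasets in proportion to data_ratio_list, wrapping around with a modulo the way fetch_record does. All names below are illustrative.

def mixed_chunk(datasets, ratios, chunk=1000):
    # datasets: list of sample lists; ratios: fractions that sum to ~1
    out = []
    for samples, r in zip(datasets, ratios):
        take = max(1, int(chunk * r))
        out.extend(samples[i % len(samples)] for i in range(take))  # wrap around
    return out

icdar15 = ['ic15_%d' % i for i in range(300)]
coco = ['coco_%d' % i for i in range(1200)]
print(len(mixed_chunk([icdar15, coco], [0.45, 0.55])))  # -> 1000, 45% from icdar15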
@@ -0,0 +1,898 @@
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

import math
import cv2
import numpy as np
import json


class SASTProcessTrain(object):
    """
    SAST process function for training
    """
    def __init__(self, params):
        self.img_set_dir = params['img_set_dir']
        self.min_crop_side_ratio = params['min_crop_side_ratio']
        self.min_crop_size = params['min_crop_size']
        image_shape = params['image_shape']
        self.input_size = image_shape[1]
        self.min_text_size = params['min_text_size']
        self.max_text_size = params['max_text_size']

    def convert_label_infor(self, label_infor):
        label_infor = label_infor.decode()
        label_infor = label_infor.encode('utf-8').decode('utf-8-sig')
        substr = label_infor.strip("\n").split("\t")
        img_path = self.img_set_dir + substr[0]
        label = json.loads(substr[1])
        nBox = len(label)
        wordBBs, txts, txt_tags = [], [], []
        for bno in range(0, nBox):
            wordBB = label[bno]['points']
            txt = label[bno]['transcription']
            wordBBs.append(wordBB)
            txts.append(txt)
            if txt == '###':
                txt_tags.append(True)
            else:
                txt_tags.append(False)
        wordBBs = np.array(wordBBs, dtype=np.float32)
        txt_tags = np.array(txt_tags, dtype=np.bool)
        return img_path, wordBBs, txt_tags, txts
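For reference, a hypothetical line from one of the train_label_json.txt files listed in the reader configs (path and content invented for illustration); this is exactly the shape convert_label_infor parses: an image path, a tab, then a JSON list with one dict per text instance.

# icdar15/img_1.jpg\t[{"points": [[10, 20], [90, 20], [90, 50], [10, 50]], "transcription": "PADDLE"}]
# a transcription of '###' marks the instance as ignored (txt_tags -> True)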

    def quad_area(self, poly):
        """
        Compute the signed area of a quad via the shoelace formula;
        the sign encodes the vertex order.
        :param poly: (4, 2) array of vertices
        :return: signed area
        """
        edge = [
            (poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
            (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
            (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
            (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])
        ]
        return np.sum(edge) / 2.
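A quick numeric check (illustrative only, assuming `proc` is a SASTProcessTrain instance): for the square below the four edge terms are 0, 0, -2 and 0, so quad_area returns -1.0; reversing the vertex order flips the sign, which is what check_and_validate_polys keys on.

square = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], dtype=np.float32)
# proc.quad_area(square) == -1.0; proc.quad_area(square[::-1]) == 1.0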

    def gen_quad_from_poly(self, poly):
        """
        Generate min area quad from poly.
        """
        point_num = poly.shape[0]
        min_area_quad = np.zeros((4, 2), dtype=np.float32)
        rect = cv2.minAreaRect(poly.astype(np.int32))  # (center (x,y), (width, height), angle of rotation)
        box = np.array(cv2.boxPoints(rect))

        # rotate the box so that its first corner is the one closest to the
        # poly's own corner ordering (keeps point order consistent)
        first_point_idx = 0
        min_dist = 1e4
        for i in range(4):
            dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
                   np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
                   np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
                   np.linalg.norm(box[(i + 3) % 4] - poly[-1])
            if dist < min_dist:
                min_dist = dist
                first_point_idx = i
        for i in range(4):
            min_area_quad[i] = box[(first_point_idx + i) % 4]

        return min_area_quad

    def check_and_validate_polys(self, polys, tags, im_size):
        """
        Check that the text polys are all in the same direction and
        filter out invalid polygons.
        :param polys:
        :param tags:
        :param im_size: (h, w) of the image
        :return:
        """
        (h, w) = im_size
        if polys.shape[0] == 0:
            return polys, np.array([]), np.array([])
        polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
        polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)

        validated_polys = []
        validated_tags = []
        hv_tags = []
        for poly, tag in zip(polys, tags):
            quad = self.gen_quad_from_poly(poly)
            p_area = self.quad_area(quad)
            if abs(p_area) < 1:
                print('invalid poly')
                continue
            if p_area > 0:
                if not tag:
                    print('poly in wrong direction')
                    tag = True  # reversed cases should be ignored
                poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1), :]
                quad = quad[(0, 3, 2, 1), :]

            len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] - quad[2])
            len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])
            hv_tag = 1

            if len_w * 2.0 < len_h:
                hv_tag = 0

            validated_polys.append(poly)
            validated_tags.append(tag)
            hv_tags.append(hv_tag)
        return np.array(validated_polys), np.array(validated_tags), np.array(hv_tags)

    def crop_area(self, im, polys, tags, hv_tags, txts, crop_background=False, max_tries=25):
        """
        Make a random crop from the input image that never cuts through a text box.
        :param im:
        :param polys:
        :param tags:
        :param crop_background:
        :param max_tries:
        :return:
        """
        h, w, _ = im.shape
        pad_h = h // 10
        pad_w = w // 10
        h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
        w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
        for poly in polys:
            poly = np.round(poly, decimals=0).astype(np.int32)
            minx = np.min(poly[:, 0])
            maxx = np.max(poly[:, 0])
            w_array[minx + pad_w: maxx + pad_w] = 1
            miny = np.min(poly[:, 1])
            maxy = np.max(poly[:, 1])
            h_array[miny + pad_h: maxy + pad_h] = 1
        # ensure the cropped area does not cross any text region
        h_axis = np.where(h_array == 0)[0]
        w_axis = np.where(w_array == 0)[0]
        if len(h_axis) == 0 or len(w_axis) == 0:
            return im, polys, tags, hv_tags, txts
        for i in range(max_tries):
            xx = np.random.choice(w_axis, size=2)
            xmin = np.min(xx) - pad_w
            xmax = np.max(xx) - pad_w
            xmin = np.clip(xmin, 0, w - 1)
            xmax = np.clip(xmax, 0, w - 1)
            yy = np.random.choice(h_axis, size=2)
            ymin = np.min(yy) - pad_h
            ymax = np.max(yy) - pad_h
            ymin = np.clip(ymin, 0, h - 1)
            ymax = np.clip(ymax, 0, h - 1)
            # if xmax - xmin < ARGS.min_crop_side_ratio * w or \
            #         ymax - ymin < ARGS.min_crop_side_ratio * h:
            if xmax - xmin < self.min_crop_size or \
                    ymax - ymin < self.min_crop_size:
                # area too small
                continue
            if polys.shape[0] != 0:
                poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
                                    & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
                selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
            else:
                selected_polys = []
            if len(selected_polys) == 0:
                # no text in this area
                if crop_background:
                    txts_tmp = []
                    for selected_poly in selected_polys:
                        txts_tmp.append(txts[selected_poly])
                    txts = txts_tmp
                    return im[ymin: ymax + 1, xmin: xmax + 1, :], \
                           polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
                else:
                    continue
            im = im[ymin: ymax + 1, xmin: xmax + 1, :]
            polys = polys[selected_polys]
            tags = tags[selected_polys]
            hv_tags = hv_tags[selected_polys]
            txts_tmp = []
            for selected_poly in selected_polys:
                txts_tmp.append(txts[selected_poly])
            txts = txts_tmp
            polys[:, :, 0] -= xmin
            polys[:, :, 1] -= ymin
            return im, polys, tags, hv_tags, txts

        return im, polys, tags, hv_tags, txts
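A minimal standalone sketch (not from the commit) of the crop trick used above: mark every row and column covered by a text box, then sample crop edges only from unmarked positions, so a crop can never slice through a box.

import numpy as np

w = 10
w_array = np.zeros(w, dtype=np.int32)
w_array[3:7] = 1                         # a text box spans columns 3..6
w_axis = np.where(w_array == 0)[0]       # legal x positions: 0, 1, 2, 7, 8, 9
print(np.random.choice(w_axis, size=2))  # two crop edges, both outside the box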

    def generate_direction_map(self, poly_quads, direction_map):
        """
        Fill each quad of a poly with its reading-direction vector
        (scaled to the average quad width) and the inverse average height.
        """
        width_list = []
        height_list = []
        for quad in poly_quads:
            quad_w = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
            quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0
            width_list.append(quad_w)
            height_list.append(quad_h)
        norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0)
        average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0)

        for quad in poly_quads:
            direct_vector_full = ((quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
            direct_vector = direct_vector_full / (np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
            direction_label = tuple(map(float, [direct_vector[0], direct_vector[1], 1.0 / (average_height + 1e-6)]))
            cv2.fillPoly(direction_map, quad.round().astype(np.int32)[np.newaxis, :, :], direction_label)
        return direction_map

    def calculate_average_height(self, poly_quads):
        """
        Calculate the average quad height of a poly.
        """
        height_list = []
        for quad in poly_quads:
            quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0
            height_list.append(quad_h)
        average_height = max(sum(height_list) / len(height_list), 1.0)
        return average_height

    def generate_tcl_label(self, hw, polys, tags, ds_ratio,
                           tcl_ratio=0.3, shrink_ratio_of_width=0.15):
        """
        Generate the text-center-line (tcl) score map, the border-offset
        (tbo) map and the training mask at ds_ratio resolution.
        """
        h, w = hw
        h, w = int(h * ds_ratio), int(w * ds_ratio)
        polys = polys * ds_ratio

        score_map = np.zeros((h, w,), dtype=np.float32)
        tbo_map = np.zeros((h, w, 5), dtype=np.float32)
        training_mask = np.ones((h, w,), dtype=np.float32)
        direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape([1, 1, 3]).astype(np.float32)

        for poly_idx, poly_tag in enumerate(zip(polys, tags)):
            poly = poly_tag[0]
            tag = poly_tag[1]

            # generate min_area_quad
            min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
            min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
                                     np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
            min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
                                     np.linalg.norm(min_area_quad[2] - min_area_quad[3]))

            if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
                    or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
                continue

            if tag:
                # ignored text: down-weight it in the mask instead of supervising it
                cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0.15)
            else:
                tcl_poly = self.poly2tcl(poly, tcl_ratio)
                tcl_quads = self.poly2quads(tcl_poly)
                poly_quads = self.poly2quads(poly)
                # stcl map
                stcl_quads, quad_index = self.shrink_poly_along_width(tcl_quads, shrink_ratio_of_width=shrink_ratio_of_width,
                                                                      expand_height_ratio=1.0 / tcl_ratio)
                # generate tcl map
                cv2.fillPoly(score_map, np.round(stcl_quads).astype(np.int32), 1.0)

                # generate tbo map
                for idx, quad in enumerate(stcl_quads):
                    quad_mask = np.zeros((h, w), dtype=np.float32)
                    quad_mask = cv2.fillPoly(quad_mask, np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
                    tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]], quad_mask, tbo_map)
        return score_map, tbo_map, training_mask

    def generate_tvo_and_tco(self, hw, polys, tags, tcl_ratio=0.3, ds_ratio=0.25):
        """
        Generate the tvo (vertex offset) and tco (center offset) maps.
        """
        h, w = hw
        h, w = int(h * ds_ratio), int(w * ds_ratio)
        polys = polys * ds_ratio
        poly_mask = np.zeros((h, w), dtype=np.float32)

        tvo_map = np.ones((9, h, w), dtype=np.float32)
        tvo_map[0:-1:2] = np.tile(np.arange(0, w), (h, 1))
        tvo_map[1:-1:2] = np.tile(np.arange(0, w), (h, 1)).T
        poly_tv_xy_map = np.zeros((8, h, w), dtype=np.float32)

        # tco map
        tco_map = np.ones((3, h, w), dtype=np.float32)
        tco_map[0] = np.tile(np.arange(0, w), (h, 1))
        tco_map[1] = np.tile(np.arange(0, w), (h, 1)).T
        poly_tc_xy_map = np.zeros((2, h, w), dtype=np.float32)

        poly_short_edge_map = np.ones((h, w), dtype=np.float32)

        for poly, poly_tag in zip(polys, tags):

            if poly_tag:
                continue

            # adjust point order for vertical poly
            poly = self.adjust_point(poly)

            # generate min_area_quad
            min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
            min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
                                     np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
            min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
                                     np.linalg.norm(min_area_quad[2] - min_area_quad[3]))

            # generate tcl map and text, 128 * 128
            tcl_poly = self.poly2tcl(poly, tcl_ratio)

            # generate poly_tv_xy_map
            for idx in range(4):
                cv2.fillPoly(poly_tv_xy_map[2 * idx],
                             np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
                             float(min(max(min_area_quad[idx, 0], 0), w)))
                cv2.fillPoly(poly_tv_xy_map[2 * idx + 1],
                             np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
                             float(min(max(min_area_quad[idx, 1], 0), h)))

            # generate poly_tc_xy_map
            for idx in range(2):
                cv2.fillPoly(poly_tc_xy_map[idx],
                             np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), float(center_point[idx]))

            # generate poly_short_edge_map
            cv2.fillPoly(poly_short_edge_map,
                         np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
                         float(max(min(min_area_quad_h, min_area_quad_w), 1.0)))

            # generate poly_mask and training_mask
            cv2.fillPoly(poly_mask, np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), 1)

        tvo_map *= poly_mask
        tvo_map[:8] -= poly_tv_xy_map
        tvo_map[-1] /= poly_short_edge_map
        tvo_map = tvo_map.transpose((1, 2, 0))

        tco_map *= poly_mask
        tco_map[:2] -= poly_tc_xy_map
        tco_map[-1] /= poly_short_edge_map
        tco_map = tco_map.transpose((1, 2, 0))

        return tvo_map, tco_map

    def adjust_point(self, poly):
        """
        adjust point order.
        """
        point_num = poly.shape[0]
        if point_num == 4:
            len_1 = np.linalg.norm(poly[0] - poly[1])
            len_2 = np.linalg.norm(poly[1] - poly[2])
            len_3 = np.linalg.norm(poly[2] - poly[3])
            len_4 = np.linalg.norm(poly[3] - poly[0])

            if (len_1 + len_3) * 1.5 < (len_2 + len_4):
                poly = poly[[1, 2, 3, 0], :]

        elif point_num > 4:
            vector_1 = poly[0] - poly[1]
            vector_2 = poly[1] - poly[2]
            cos_theta = np.dot(vector_1, vector_2) / (np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
            theta = np.arccos(np.round(cos_theta, decimals=4))

            if abs(theta) > (70 / 180 * math.pi):
                index = list(range(1, point_num)) + [0]
                poly = poly[np.array(index), :]
        return poly

    def gen_min_area_quad_from_poly(self, poly):
        """
        Generate min area quad from poly.
        """
        point_num = poly.shape[0]
        min_area_quad = np.zeros((4, 2), dtype=np.float32)
        if point_num == 4:
            min_area_quad = poly
            center_point = np.sum(poly, axis=0) / 4
        else:
            rect = cv2.minAreaRect(poly.astype(np.int32))  # (center (x,y), (width, height), angle of rotation)
            center_point = rect[0]
            box = np.array(cv2.boxPoints(rect))

            first_point_idx = 0
            min_dist = 1e4
            for i in range(4):
                dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
                       np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
                       np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
                       np.linalg.norm(box[(i + 3) % 4] - poly[-1])
                if dist < min_dist:
                    min_dist = dist
                    first_point_idx = i

            for i in range(4):
                min_area_quad[i] = box[(first_point_idx + i) % 4]

        return min_area_quad, center_point

    def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
        """
        Generate shrink_quad_along_width.
        """
        ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
        p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
        p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
        return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])

    def shrink_poly_along_width(self, quads, shrink_ratio_of_width, expand_height_ratio=1.0):
        """
        shrink poly with given length.
        """
        upper_edge_list = []

        def get_cut_info(edge_len_list, cut_len):
            for idx, edge_len in enumerate(edge_len_list):
                cut_len -= edge_len
                if cut_len <= 0.000001:
                    ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
                    return idx, ratio

        for quad in quads:
            upper_edge_len = np.linalg.norm(quad[0] - quad[1])
            upper_edge_list.append(upper_edge_len)

        # length of left edge and right edge.
        left_length = np.linalg.norm(quads[0][0] - quads[0][3]) * expand_height_ratio
        right_length = np.linalg.norm(quads[-1][1] - quads[-1][2]) * expand_height_ratio

        shrink_length = min(left_length, right_length, sum(upper_edge_list)) * shrink_ratio_of_width
        # shrinking length
        upper_len_left = shrink_length
        upper_len_right = sum(upper_edge_list) - shrink_length

        left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
        left_quad = self.shrink_quad_along_width(quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
        right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
        right_quad = self.shrink_quad_along_width(quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)

        out_quad_list = []
        if left_idx == right_idx:
            out_quad_list.append([left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
        else:
            out_quad_list.append(left_quad)
            for idx in range(left_idx + 1, right_idx):
                out_quad_list.append(quads[idx])
            out_quad_list.append(right_quad)

        return np.array(out_quad_list), list(range(left_idx, right_idx + 1))

    def vector_angle(self, A, B):
        """
        Calculate the angle between vector AB and x-axis positive direction.
        """
        AB = np.array([B[1] - A[1], B[0] - A[0]])
        return np.arctan2(*AB)

    def theta_line_cross_point(self, theta, point):
        """
        Calculate the line through given point and angle in ax + by + c = 0 form.
        """
        x, y = point
        cos = np.cos(theta)
        sin = np.sin(theta)
        return [sin, -cos, cos * y - sin * x]

    def line_cross_two_point(self, A, B):
        """
        Calculate the line through given points A and B in ax + by + c = 0 form.
        """
        angle = self.vector_angle(A, B)
        return self.theta_line_cross_point(angle, A)

    def average_angle(self, poly):
        """
        Calculate the average angle between left and right edge in given poly.
        """
        p0, p1, p2, p3 = poly
        angle30 = self.vector_angle(p3, p0)
        angle21 = self.vector_angle(p2, p1)
        return (angle30 + angle21) / 2

    def line_cross_point(self, line1, line2):
        """
        line1 and line2 in 0 = ax + by + c form, compute the cross point of line1 and line2
        """
        a1, b1, c1 = line1
        a2, b2, c2 = line2
        d = a1 * b2 - a2 * b1

        if d == 0:
            # print("line1", line1)
            # print("line2", line2)
            print('Cross point does not exist')
            return np.array([0, 0], dtype=np.float32)
        else:
            x = (b1 * c2 - b2 * c1) / d
            y = (a2 * c1 - a1 * c2) / d

        return np.array([x, y], dtype=np.float32)
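A quick sanity check of line_cross_point (illustrative): the line y = 1 is [0, 1, -1] in ax + by + c = 0 form and the line x = 2 is [1, 0, -2]; their intersection should be (2, 1).

# assuming `proc` is a SASTProcessTrain instance:
# proc.line_cross_point([0, 1, -1], [1, 0, -2])  -> array([2., 1.], dtype=float32)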

    def quad2tcl(self, poly, ratio):
        """
        Generate center line by poly clock-wise point. (4, 2)
        """
        ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
        p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
        p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
        return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])

    def poly2tcl(self, poly, ratio):
        """
        Generate center line by poly clock-wise point.
        """
        ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
        tcl_poly = np.zeros_like(poly)
        point_num = poly.shape[0]

        for idx in range(point_num // 2):
            point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]) * ratio_pair
            tcl_poly[idx] = point_pair[0]
            tcl_poly[point_num - 1 - idx] = point_pair[1]
        return tcl_poly
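A tiny numeric check of poly2tcl (illustrative): for an axis-aligned 100 x 10 box and ratio = 0.3, ratio_pair is [[0.35], [0.65]], so the center line keeps the band y in [3.5, 6.5].

box = np.array([[0, 0], [100, 0], [100, 10], [0, 10]], dtype=np.float32)
# proc.poly2tcl(box, 0.3) -> [[0, 3.5], [100, 3.5], [100, 6.5], [0, 6.5]]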

    def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
        """
        Generate tbo_map for the given quad.
        """
        # upper and lower line function: ax + by + c = 0;
        up_line = self.line_cross_two_point(quad[0], quad[1])
        lower_line = self.line_cross_two_point(quad[3], quad[2])

        quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2]))
        quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3]))

        # average angle of left and right line.
        angle = self.average_angle(quad)

        xy_in_poly = np.argwhere(tcl_mask == 1)
        for y, x in xy_in_poly:
            point = (x, y)
            line = self.theta_line_cross_point(angle, point)
            cross_point_upper = self.line_cross_point(up_line, line)
            cross_point_lower = self.line_cross_point(lower_line, line)
            # NOTE: offsets are stored y-first
            upper_offset_x, upper_offset_y = cross_point_upper - point
            lower_offset_x, lower_offset_y = cross_point_lower - point
            tbo_map[y, x, 0] = upper_offset_y
            tbo_map[y, x, 1] = upper_offset_x
            tbo_map[y, x, 2] = lower_offset_y
            tbo_map[y, x, 3] = lower_offset_x
            tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
        return tbo_map

    def poly2quads(self, poly):
        """
        Split poly into quads.
        """
        quad_list = []
        point_num = poly.shape[0]

        # point pair
        point_pair_list = []
        for idx in range(point_num // 2):
            point_pair = [poly[idx], poly[point_num - 1 - idx]]
            point_pair_list.append(point_pair)

        quad_num = point_num // 2 - 1
        for idx in range(quad_num):
            # reshape and adjust to clockwise order
            quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]).reshape(4, 2)[[0, 2, 3, 1]])

        return np.array(quad_list)

    def rotate_im_poly(self, im, text_polys):
        """
        Rotate image and polys by 90 / 180 / 270 degrees.
        """
        im_w, im_h = im.shape[1], im.shape[0]
        dst_im = im.copy()
        dst_polys = []
        rand_degree_ratio = np.random.rand()
        rand_degree_cnt = 1
        # if rand_degree_ratio > 0.333 and rand_degree_ratio < 0.666:
        #     rand_degree_cnt = 2
        # elif rand_degree_ratio > 0.666:
        if rand_degree_ratio > 0.5:
            rand_degree_cnt = 3
        for i in range(rand_degree_cnt):
            dst_im = np.rot90(dst_im)
        rot_degree = -90 * rand_degree_cnt
        rot_angle = rot_degree * math.pi / 180.0
        n_poly = text_polys.shape[0]
        cx, cy = 0.5 * im_w, 0.5 * im_h
        ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
        for i in range(n_poly):
            wordBB = text_polys[i]
            poly = []
            for j in range(4):  # 16 -> 4 points
                sx, sy = wordBB[j][0], wordBB[j][1]
                dx = math.cos(rot_angle) * (sx - cx) - math.sin(rot_angle) * (sy - cy) + ncx
                dy = math.sin(rot_angle) * (sx - cx) + math.cos(rot_angle) * (sy - cy) + ncy
                poly.append([dx, dy])
            dst_polys.append(poly)
        return dst_im, np.array(dst_polys, dtype=np.float32)

    def extract_polys(self, poly_txt_path):
        """
        Read text_polys, txt_tags and txts from a given txt file.
        """
        text_polys, txt_tags, txts = [], [], []

        with open(poly_txt_path) as f:
            for line in f.readlines():
                poly_str, txt = line.strip().split('\t')
                poly = list(map(float, poly_str.split(',')))  # list() needed on Python 3
                text_polys.append(np.array(poly, dtype=np.float32).reshape(-1, 2))
                txts.append(txt)
                if txt == '###':
                    txt_tags.append(True)
                else:
                    txt_tags.append(False)

        return np.array([np.array(p) for p in text_polys]), \
               np.array(txt_tags, dtype=np.bool), txts

    def __call__(self, label_infor):
        infor = self.convert_label_infor(label_infor)
        im_path, text_polys, text_tags, text_strs = infor
        im = cv2.imread(im_path)
        if im is None:
            return None
        if text_polys.shape[0] == 0:
            return None
        # #add rotate cases
        # if np.random.rand() < 0.5:
        #     im, text_polys = self.rotate_im_poly(im, text_polys)
        h, w, _ = im.shape
        # text_polys, text_tags = self.check_and_validate_polys(text_polys,
        #                                                       text_tags, h, w)
        text_polys, text_tags, hv_tags = self.check_and_validate_polys(text_polys, text_tags, (h, w))

        if text_polys.shape[0] == 0:
            return None

        # # random scale this image
        # rd_scale = np.random.choice(self.random_scale)
        # im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
        # text_polys *= rd_scale
        # if np.random.rand() < self.background_ratio:
        #     outs = self.crop_background_infor(im, text_polys, text_tags,
        #                                       text_strs)
        # else:
        #     outs = self.crop_foreground_infor(im, text_polys, text_tags,
        #                                       text_strs)

        # if outs is None:
        #     return None
        # im, score_map, geo_map, training_mask = outs
        # score_map = score_map[np.newaxis, ::4, ::4].astype(np.float32)
        # geo_map = np.swapaxes(geo_map, 1, 2)
        # geo_map = np.swapaxes(geo_map, 1, 0)
        # geo_map = geo_map[:, ::4, ::4].astype(np.float32)
        # training_mask = training_mask[np.newaxis, ::4, ::4]
        # training_mask = training_mask.astype(np.float32)
        # return im, score_map, geo_map, training_mask

        # randomize the aspect ratio while keeping the area roughly fixed
        asp_scales = np.arange(1.0, 1.55, 0.1)
        asp_scale = np.random.choice(asp_scales)

        if np.random.rand() < 0.5:
            asp_scale = 1.0 / asp_scale
        asp_scale = math.sqrt(asp_scale)

        asp_wx = asp_scale
        asp_hy = 1.0 / asp_scale
        im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
        text_polys[:, :, 0] *= asp_wx
        text_polys[:, :, 1] *= asp_hy

        h, w, _ = im.shape
        if max(h, w) > 2048:
            rd_scale = 2048.0 / max(h, w)
            im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
            text_polys *= rd_scale
        h, w, _ = im.shape
        if min(h, w) < 16:
            return None

        # no background crops
        im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(im, \
            text_polys, text_tags, hv_tags, text_strs, crop_background=False)
        if text_polys.shape[0] == 0:
            return None
        # skip the sample if every instance is an ignore case
        if np.sum((text_tags * 1.0)) >= text_tags.size:
            return None
        new_h, new_w, _ = im.shape
        if (new_h is None) or (new_w is None):
            return None
        # resize image
        std_ratio = float(self.input_size) / max(new_w, new_h)
        rand_scales = np.array([0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
        rz_scale = std_ratio * np.random.choice(rand_scales)
        im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
        text_polys[:, :, 0] *= rz_scale
        text_polys[:, :, 1] *= rz_scale

        # add gaussian blur
        if np.random.rand() < 0.1 * 0.5:
            ks = np.random.permutation(5)[0] + 1
            ks = int(ks / 2) * 2 + 1
            im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
        # brighten
        if np.random.rand() < 0.1 * 0.5:
            im = im * (1.0 + np.random.rand() * 0.5)
            im = np.clip(im, 0.0, 255.0)
        # darken
        if np.random.rand() < 0.1 * 0.5:
            im = im * (1.0 - np.random.rand() * 0.5)
            im = np.clip(im, 0.0, 255.0)

        # pad the image to [input_size, input_size] with the ImageNet mean
        new_h, new_w, _ = im.shape
        if min(new_w, new_h) < self.input_size * 0.5:
            return None

        im_padded = np.ones((self.input_size, self.input_size, 3), dtype=np.float32)
        im_padded[:, :, 2] = 0.485 * 255
        im_padded[:, :, 1] = 0.456 * 255
        im_padded[:, :, 0] = 0.406 * 255

        # randomize the start position
        del_h = self.input_size - new_h
        del_w = self.input_size - new_w
        sh, sw = 0, 0
        if del_h > 1:
            sh = int(np.random.rand() * del_h)
        if del_w > 1:
            sw = int(np.random.rand() * del_w)

        # padding
        im_padded[sh: sh + new_h, sw: sw + new_w, :] = im.copy()
        text_polys[:, :, 0] += sw
        text_polys[:, :, 1] += sh

        score_map, border_map, training_mask = self.generate_tcl_label((self.input_size, self.input_size),
                                                                       text_polys, text_tags, 0.25)

        # SAST head
        tvo_map, tco_map = self.generate_tvo_and_tco((self.input_size, self.input_size), text_polys, text_tags, tcl_ratio=0.3, ds_ratio=0.25)
        # print("test--------tvo_map shape:", tvo_map.shape)

        im_padded[:, :, 2] -= 0.485 * 255
        im_padded[:, :, 1] -= 0.456 * 255
        im_padded[:, :, 0] -= 0.406 * 255
        im_padded[:, :, 2] /= (255.0 * 0.229)
        im_padded[:, :, 1] /= (255.0 * 0.224)
        im_padded[:, :, 0] /= (255.0 * 0.225)
        im_padded = im_padded.transpose((2, 0, 1))

        # images.append(im_padded[::-1, :, :])
        # tcl_maps.append(score_map[np.newaxis, :, :])
        # border_maps.append(border_map.transpose((2, 0, 1)))
        # training_masks.append(training_mask[np.newaxis, :, :])
        # tvos.append(tvo_map.transpose((2, 0, 1)))
        # tcos.append(tco_map.transpose((2, 0, 1)))

        # # After a batch should begin
        # if len(images) == batch_size:
        #     yield np.array(images, dtype=np.float32), \
        #           np.array(tcl_maps, dtype=np.float32), \
        #           np.array(tvos, dtype=np.float32), \
        #           np.array(tcos, dtype=np.float32), \
        #           np.array(border_maps, dtype=np.float32), \
        #           np.array(training_masks, dtype=np.float32), \

        #     images, tcl_maps, border_maps, training_masks = [], [], [], []
        #     tvos, tcos = [], []

        # return im_padded, score_map, border_map, training_mask, tvo_map, tco_map
        return im_padded[::-1, :, :], score_map[np.newaxis, :, :], border_map.transpose((2, 0, 1)), training_mask[np.newaxis, :, :], tvo_map.transpose((2, 0, 1)), tco_map.transpose((2, 0, 1))
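For orientation (not part of the commit), the six arrays returned above have these shapes when input_size = 512 and ds_ratio = 0.25, so every label map is 128 x 128; they match the fluid.layers.data shapes declared for the SAST branch in det_model.py further down.

# im_padded:     (3, 512, 512)  normalized image, channels flipped BGR -> RGB
# score_map:     (1, 128, 128)  shrunk text-center-line mask
# border_map:    (5, 128, 128)  tbo: upper/lower border offsets + height norm
# training_mask: (1, 128, 128)  0.15 inside ignored ('###') text
# tvo_map:       (9, 128, 128)  offsets to the four quad vertices + norm
# tco_map:       (3, 128, 128)  offsets to the quad center + norm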


class SASTProcessTest(object):
    """
    SAST process function for test
    """
    def __init__(self, params):
        super(SASTProcessTest, self).__init__()
        if 'max_side_len' in params:
            self.max_side_len = params['max_side_len']
        else:
            self.max_side_len = 2400

    # def resize_image(self, im):
    #     """
    #     resize image to a size multiple of 32 which is required by the network
    #     :param im: the resized image
    #     :param max_side_len: limit of max image size to avoid out of memory in gpu
    #     :return: the resized image and the resize ratio
    #     """
    #     max_side_len = self.max_side_len
    #     h, w, _ = im.shape
    #
    #     resize_w = w
    #     resize_h = h
    #
    #     # limit the max side
    #     if max(resize_h, resize_w) > max_side_len:
    #         if resize_h > resize_w:
    #             ratio = float(max_side_len) / resize_h
    #         else:
    #             ratio = float(max_side_len) / resize_w
    #     else:
    #         ratio = 1.
    #     resize_h = int(resize_h * ratio)
    #     resize_w = int(resize_w * ratio)
    #     if resize_h % 32 == 0:
    #         resize_h = resize_h
    #     elif resize_h // 32 <= 1:
    #         resize_h = 32
    #     else:
    #         resize_h = (resize_h // 32 - 1) * 32
    #     if resize_w % 32 == 0:
    #         resize_w = resize_w
    #     elif resize_w // 32 <= 1:
    #         resize_w = 32
    #     else:
    #         resize_w = (resize_w // 32 - 1) * 32
    #     im = cv2.resize(im, (int(resize_w), int(resize_h)))
    #     ratio_h = resize_h / float(h)
    #     ratio_w = resize_w / float(w)
    #     return im, (ratio_h, ratio_w)

    def resize_image(self, im):
        """
        Resize the image to a multiple of max_stride, as required by the network.
        :param im: the image to resize
        :return: the resized image and the (ratio_h, ratio_w) resize ratios
        """
        h, w, _ = im.shape

        resize_w = w
        resize_h = h

        # fix the longer side to max_side_len
        if resize_h > resize_w:
            ratio = float(self.max_side_len) / resize_h
        else:
            ratio = float(self.max_side_len) / resize_w

        resize_h = int(resize_h * ratio)
        resize_w = int(resize_w * ratio)

        max_stride = 128
        resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
        resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
        im = cv2.resize(im, (int(resize_w), int(resize_h)))
        ratio_h = resize_h / float(h)
        ratio_w = resize_w / float(w)

        return im, (ratio_h, ratio_w)

    def __call__(self, im):
        src_h, src_w, _ = im.shape
        im, (ratio_h, ratio_w) = self.resize_image(im)
        img_mean = [0.485, 0.456, 0.406]
        img_std = [0.229, 0.224, 0.225]
        im = im[:, :, ::-1].astype(np.float32)  # BGR -> RGB
        im = im / 255
        im -= img_mean
        im /= img_std
        im = im.transpose((2, 0, 1))
        im = im[np.newaxis, :]
        return [im, (ratio_h, ratio_w, src_h, src_w)]
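A worked example of resize_image (illustrative): with max_side_len = 768 on a 600 x 1000 image,

# ratio = 768 / 1000 = 0.768  ->  460 x 768 after the initial scale; both
# sides then round up to a multiple of max_stride = 128  ->  512 x 768,
# so (ratio_h, ratio_w) = (512 / 600, 768 / 1000).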
@@ -97,6 +97,24 @@ class DetModel(object):
                 'shrink_mask':shrink_mask,\
                 'threshold_map':threshold_map,\
                 'threshold_mask':threshold_mask}
+        elif self.algorithm == "SAST":
+            input_score = fluid.layers.data(
+                name='score', shape=[1, 128, 128], dtype='float32')
+            input_border = fluid.layers.data(
+                name='border', shape=[5, 128, 128], dtype='float32')
+            input_mask = fluid.layers.data(
+                name='mask', shape=[1, 128, 128], dtype='float32')
+            input_tvo = fluid.layers.data(
+                # name='tvo', shape=[5, 128, 128], dtype='float32')
+                name='tvo', shape=[9, 128, 128], dtype='float32')
+            input_tco = fluid.layers.data(
+                name='tco', shape=[3, 128, 128], dtype='float32')
+            feed_list = [image, input_score, input_border, input_mask, input_tvo, input_tco]
+            labels = {'input_score': input_score,\
+                'input_border': input_border,\
+                'input_mask': input_mask,\
+                'input_tvo': input_tvo,\
+                'input_tco': input_tco}
         loader = fluid.io.DataLoader.from_generator(
             feed_list=feed_list,
             capacity=64,
@@ -0,0 +1,274 @@
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle.fluid as fluid
from paddle.fluid.param_attr import ParamAttr

__all__ = ["ResNet"]


class ResNet(object):
    def __init__(self, params):
        """
        The ResNet backbone network for the detection module.
        Args:
            params(dict): the super parameters for network build
        """
        self.layers = params['layers']
        supported_layers = [18, 34, 50, 101, 152]
        assert self.layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(supported_layers, self.layers)
        self.is_3x3 = True

    def __call__(self, input):
        layers = self.layers
        is_3x3 = self.is_3x3
        # if layers == 18:
        #     depth = [2, 2, 2, 2]
        # elif layers == 34 or layers == 50:
        #     depth = [3, 4, 6, 3]
        # elif layers == 101:
        #     depth = [3, 4, 23, 3]
        # elif layers == 152:
        #     depth = [3, 8, 36, 3]
        # elif layers == 200:
        #     depth = [3, 12, 48, 3]
        # num_filters = [64, 128, 256, 512]
        # outs = []

        if layers == 18:
            depth = [2, 2, 2, 2]
        elif layers == 34 or layers == 50:
            # depth = [3, 4, 6, 3]
            depth = [3, 4, 6, 3, 3]  # an extra fifth stage for SAST
        elif layers == 101:
            depth = [3, 4, 23, 3]
        elif layers == 152:
            depth = [3, 8, 36, 3]
        num_filters = [64, 128, 256, 512, 512]
        blocks = {}

        idx = 'block_0'
        blocks[idx] = input

        if not is_3x3:
            conv = self.conv_bn_layer(
                input=input,
                num_filters=64,
                filter_size=7,
                stride=2,
                act='relu')
        else:
            conv = self.conv_bn_layer(
                input=input,
                num_filters=32,
                filter_size=3,
                stride=2,
                act='relu',
                name='conv1_1')
            conv = self.conv_bn_layer(
                input=conv,
                num_filters=32,
                filter_size=3,
                stride=1,
                act='relu',
                name='conv1_2')
            conv = self.conv_bn_layer(
                input=conv,
                num_filters=64,
                filter_size=3,
                stride=1,
                act='relu',
                name='conv1_3')
        idx = 'block_1'
        blocks[idx] = conv

        conv = fluid.layers.pool2d(
            input=conv,
            pool_size=3,
            pool_stride=2,
            pool_padding=1,
            pool_type='max')

        if layers >= 50:
            for block in range(len(depth)):
                for i in range(depth[block]):
                    if layers in [101, 152, 200] and block == 2:
                        if i == 0:
                            conv_name = "res" + str(block + 2) + "a"
                        else:
                            conv_name = "res" + str(block + 2) + "b" + str(i)
                    else:
                        conv_name = "res" + str(block + 2) + chr(97 + i)
                    conv = self.bottleneck_block(
                        input=conv,
                        num_filters=num_filters[block],
                        stride=2 if i == 0 and block != 0 else 1,
                        if_first=block == i == 0,
                        name=conv_name)
                # outs.append(conv)
                idx = 'block_' + str(block + 2)
                blocks[idx] = conv
        else:
            for block in range(len(depth)):
                for i in range(depth[block]):
                    conv_name = "res" + str(block + 2) + chr(97 + i)
                    conv = self.basic_block(
                        input=conv,
                        num_filters=num_filters[block],
                        stride=2 if i == 0 and block != 0 else 1,
                        if_first=block == i == 0,
                        name=conv_name)
                # outs.append(conv)
                idx = 'block_' + str(block + 2)
                blocks[idx] = conv
        # return outs
        return blocks

    def conv_bn_layer(self,
                      input,
                      num_filters,
                      filter_size,
                      stride=1,
                      groups=1,
                      act=None,
                      name=None):
        conv = fluid.layers.conv2d(
            input=input,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)
        if name == "conv1":
            bn_name = "bn_" + name
        else:
            bn_name = "bn" + name[3:]
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def conv_bn_layer_new(self,
                          input,
                          num_filters,
                          filter_size,
                          stride=1,
                          groups=1,
                          act=None,
                          name=None):
        pool = fluid.layers.pool2d(
            input=input,
            pool_size=2,
            pool_stride=2,
            pool_padding=0,
            pool_type='avg',
            ceil_mode=True)

        conv = fluid.layers.conv2d(
            input=pool,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=1,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            param_attr=ParamAttr(name=name + "_weights"),
            bias_attr=False)
        if name == "conv1":
            bn_name = "bn_" + name
        else:
            bn_name = "bn" + name[3:]
        return fluid.layers.batch_norm(
            input=conv,
            act=act,
            param_attr=ParamAttr(name=bn_name + '_scale'),
            bias_attr=ParamAttr(bn_name + '_offset'),
            moving_mean_name=bn_name + '_mean',
            moving_variance_name=bn_name + '_variance')

    def shortcut(self, input, ch_out, stride, name, if_first=False):
        ch_in = input.shape[1]
        if ch_in != ch_out or stride != 1:
            if if_first:
                return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
            else:
                return self.conv_bn_layer_new(
                    input, ch_out, 1, stride, name=name)
        elif if_first:
            return self.conv_bn_layer(input, ch_out, 1, stride, name=name)
        else:
            return input

    def bottleneck_block(self, input, num_filters, stride, name, if_first):
        conv0 = self.conv_bn_layer(
            input=input,
            num_filters=num_filters,
            filter_size=1,
            act='relu',
            name=name + "_branch2a")
        conv1 = self.conv_bn_layer(
            input=conv0,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act='relu',
            name=name + "_branch2b")
        conv2 = self.conv_bn_layer(
            input=conv1,
            num_filters=num_filters * 4,
            filter_size=1,
            act=None,
            name=name + "_branch2c")

        short = self.shortcut(
            input,
            num_filters * 4,
            stride,
            if_first=if_first,
            name=name + "_branch1")

        return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')

    def basic_block(self, input, num_filters, stride, name, if_first):
        conv0 = self.conv_bn_layer(
            input=input,
            num_filters=num_filters,
            filter_size=3,
            act='relu',
            stride=stride,
            name=name + "_branch2a")
        conv1 = self.conv_bn_layer(
            input=conv0,
            num_filters=num_filters,
            filter_size=3,
            act=None,
            name=name + "_branch2b")
        short = self.shortcut(
            input,
            num_filters,
            stride,
            if_first=if_first,
            name=name + "_branch1")
        return fluid.layers.elementwise_add(x=short, y=conv1, act='relu')
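For orientation (illustrative, assuming the 3 x 512 x 512 image_shape from the configs above), the entries of the blocks dict returned by __call__ have these spatial sizes:

# block_0: 512 x 512  raw input
# block_1: 256 x 256  after the stride-2 stem
# block_2: 128 x 128  1/4   (first residual stage, stride 1)
# block_3:  64 x 64   1/8
# block_4:  32 x 32   1/16
# block_5:  16 x 16   1/32
# block_6:   8 x 8    1/64  (the extra fifth stage added for SAST)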
@@ -0,0 +1,228 @@
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle.fluid as fluid
from ..common_functions import conv_bn_layer, deconv_bn_layer
from collections import OrderedDict


class SASTHead(object):
    """
    SAST:
    see arxiv: https://
    args:
        params(dict): the super parameters for network build
    """

    def __init__(self, params):
        self.model_name = params['model_name']
        self.with_cab = params['with_cab']

    def FPN_Up_Fusion(self, blocks):
        """
        blocks{}: contains block_2, block_3, block_4, block_5 and block_6 with
        1/4, 1/8, 1/16, 1/32 and 1/64 resolution.
        """
        f = [blocks['block_6'], blocks['block_5'], blocks['block_4'], blocks['block_3'], blocks['block_2']]
        num_outputs = [256, 256, 192, 192, 128]
        g = [None, None, None, None, None]
        h = [None, None, None, None, None]
        for i in range(5):
            h[i] = conv_bn_layer(input=f[i], num_filters=num_outputs[i],
                                 filter_size=1, stride=1, act=None, name='fpn_up_h' + str(i))

        for i in range(4):
            if i == 0:
                g[i] = deconv_bn_layer(input=h[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g0')
                print("g[{}] shape: {}".format(i, g[i].shape))
            else:
                g[i] = fluid.layers.elementwise_add(x=g[i - 1], y=h[i])
                g[i] = fluid.layers.relu(g[i])
                # g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i],
                #                      filter_size=1, stride=1, act='relu')
                g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i],
                                     filter_size=3, stride=1, act='relu', name='fpn_up_g%d_1' % i)
                g[i] = deconv_bn_layer(input=g[i], num_filters=num_outputs[i + 1], act=None, name='fpn_up_g%d_2' % i)
                print("g[{}] shape: {}".format(i, g[i].shape))

        g[4] = fluid.layers.elementwise_add(x=g[3], y=h[4])
        g[4] = fluid.layers.relu(g[4])
        g[4] = conv_bn_layer(input=g[4], num_filters=num_outputs[4],
                             filter_size=3, stride=1, act='relu', name='fpn_up_fusion_1')
        g[4] = conv_bn_layer(input=g[4], num_filters=num_outputs[4],
                             filter_size=1, stride=1, act=None, name='fpn_up_fusion_2')

        return g[4]

    def FPN_Down_Fusion(self, blocks):
        """
        blocks{}: contains block_0, block_1 and block_2 with
        1/1, 1/2 and 1/4 resolution.
        """
        f = [blocks['block_0'], blocks['block_1'], blocks['block_2']]
        num_outputs = [32, 64, 128]
        g = [None, None, None]
        h = [None, None, None]
        for i in range(3):
            h[i] = conv_bn_layer(input=f[i], num_filters=num_outputs[i],
                                 filter_size=3, stride=1, act=None, name='fpn_down_h' + str(i))
        for i in range(2):
            if i == 0:
                g[i] = conv_bn_layer(input=h[i], num_filters=num_outputs[i + 1], filter_size=3, stride=2, act=None, name='fpn_down_g0')
            else:
                g[i] = fluid.layers.elementwise_add(x=g[i - 1], y=h[i])
                g[i] = fluid.layers.relu(g[i])
                g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i], filter_size=3, stride=1, act='relu', name='fpn_down_g%d_1' % i)
                g[i] = conv_bn_layer(input=g[i], num_filters=num_outputs[i + 1], filter_size=3, stride=2, act=None, name='fpn_down_g%d_2' % i)
            print("g[{}] shape: {}".format(i, g[i].shape))
        g[2] = fluid.layers.elementwise_add(x=g[1], y=h[2])
        g[2] = fluid.layers.relu(g[2])
        g[2] = conv_bn_layer(input=g[2], num_filters=num_outputs[2],
                             filter_size=3, stride=1, act='relu', name='fpn_down_fusion_1')
        g[2] = conv_bn_layer(input=g[2], num_filters=num_outputs[2],
                             filter_size=1, stride=1, act=None, name='fpn_down_fusion_2')
        return g[2]

    def SAST_Header1(self, f_common):
        """Detector header."""
        #f_score
        f_score = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_score1')
        f_score = conv_bn_layer(input=f_score, num_filters=64, filter_size=3, stride=1, act='relu', name='f_score2')
        f_score = conv_bn_layer(input=f_score, num_filters=128, filter_size=1, stride=1, act='relu', name='f_score3')
        f_score = conv_bn_layer(input=f_score, num_filters=1, filter_size=3, stride=1, name='f_score4')
        f_score = fluid.layers.sigmoid(f_score)
        print("f_score shape: {}".format(f_score.shape))

        #f_border
        f_border = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_border1')
        f_border = conv_bn_layer(input=f_border, num_filters=64, filter_size=3, stride=1, act='relu', name='f_border2')
        f_border = conv_bn_layer(input=f_border, num_filters=128, filter_size=1, stride=1, act='relu', name='f_border3')
        f_border = conv_bn_layer(input=f_border, num_filters=4, filter_size=3, stride=1, name='f_border4')
        print("f_border shape: {}".format(f_border.shape))

        return f_score, f_border

    def SAST_Header2(self, f_common):
        """Detector header."""
        #f_tvo
        f_tvo = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_tvo1')
        f_tvo = conv_bn_layer(input=f_tvo, num_filters=64, filter_size=3, stride=1, act='relu', name='f_tvo2')
        f_tvo = conv_bn_layer(input=f_tvo, num_filters=128, filter_size=1, stride=1, act='relu', name='f_tvo3')
        f_tvo = conv_bn_layer(input=f_tvo, num_filters=8, filter_size=3, stride=1, name='f_tvo4')
        print("f_tvo shape: {}".format(f_tvo.shape))

        #f_tco
        f_tco = conv_bn_layer(input=f_common, num_filters=64, filter_size=1, stride=1, act='relu', name='f_tco1')
        f_tco = conv_bn_layer(input=f_tco, num_filters=64, filter_size=3, stride=1, act='relu', name='f_tco2')
        f_tco = conv_bn_layer(input=f_tco, num_filters=128, filter_size=1, stride=1, act='relu', name='f_tco3')
        f_tco = conv_bn_layer(input=f_tco, num_filters=2, filter_size=3, stride=1, name='f_tco4')
        print("f_tco shape: {}".format(f_tco.shape))

        return f_tvo, f_tco

    def cross_attention(self, f_common):
        """
        Cross attention block: self-attention along the horizontal and
        vertical axes of the feature map, fused with a shortcut branch.
        """
        f_shape = fluid.layers.shape(f_common)
        f_theta = conv_bn_layer(input=f_common, num_filters=128, filter_size=1, stride=1, act='relu', name='f_theta')
        f_phi = conv_bn_layer(input=f_common, num_filters=128, filter_size=1, stride=1, act='relu', name='f_phi')
        f_g = conv_bn_layer(input=f_common, num_filters=128, filter_size=1, stride=1, act='relu', name='f_g')
        ### horizontal
        fh_theta = f_theta
        fh_phi = f_phi
        fh_g = f_g
        #flatten
        fh_theta = fluid.layers.transpose(fh_theta, [0, 2, 3, 1])
        fh_theta = fluid.layers.reshape(fh_theta, [f_shape[0] * f_shape[2], f_shape[3], 128])
        fh_phi = fluid.layers.transpose(fh_phi, [0, 2, 3, 1])
        fh_phi = fluid.layers.reshape(fh_phi, [f_shape[0] * f_shape[2], f_shape[3], 128])
        fh_g = fluid.layers.transpose(fh_g, [0, 2, 3, 1])
        fh_g = fluid.layers.reshape(fh_g, [f_shape[0] * f_shape[2], f_shape[3], 128])
        #correlation
        fh_attn = fluid.layers.matmul(fh_theta, fluid.layers.transpose(fh_phi, [0, 2, 1]))
        #scale
        fh_attn = fh_attn / (128 ** 0.5)
        fh_attn = fluid.layers.softmax(fh_attn)
        #weighted sum
        fh_weight = fluid.layers.matmul(fh_attn, fh_g)
        fh_weight = fluid.layers.reshape(fh_weight, [f_shape[0], f_shape[2], f_shape[3], 128])
        print("fh_weight: {}".format(fh_weight.shape))
        fh_weight = fluid.layers.transpose(fh_weight, [0, 3, 1, 2])
        fh_weight = conv_bn_layer(input=fh_weight, num_filters=128, filter_size=1, stride=1, name='fh_weight')
        #short cut
        fh_sc = conv_bn_layer(input=f_common, num_filters=128, filter_size=1, stride=1, name='fh_sc')
        f_h = fluid.layers.relu(fh_weight + fh_sc)
        ######
        #vertical
        fv_theta = fluid.layers.transpose(f_theta, [0, 1, 3, 2])
        fv_phi = fluid.layers.transpose(f_phi, [0, 1, 3, 2])
        fv_g = fluid.layers.transpose(f_g, [0, 1, 3, 2])
        #flatten
        fv_theta = fluid.layers.transpose(fv_theta, [0, 2, 3, 1])
        fv_theta = fluid.layers.reshape(fv_theta, [f_shape[0] * f_shape[3], f_shape[2], 128])
        fv_phi = fluid.layers.transpose(fv_phi, [0, 2, 3, 1])
        fv_phi = fluid.layers.reshape(fv_phi, [f_shape[0] * f_shape[3], f_shape[2], 128])
        fv_g = fluid.layers.transpose(fv_g, [0, 2, 3, 1])
        fv_g = fluid.layers.reshape(fv_g, [f_shape[0] * f_shape[3], f_shape[2], 128])
        #correlation
        fv_attn = fluid.layers.matmul(fv_theta, fluid.layers.transpose(fv_phi, [0, 2, 1]))
        #scale
        fv_attn = fv_attn / (128 ** 0.5)
        fv_attn = fluid.layers.softmax(fv_attn)
        #weighted sum
        fv_weight = fluid.layers.matmul(fv_attn, fv_g)
        fv_weight = fluid.layers.reshape(fv_weight, [f_shape[0], f_shape[3], f_shape[2], 128])
        print("fv_weight: {}".format(fv_weight.shape))
        fv_weight = fluid.layers.transpose(fv_weight, [0, 3, 2, 1])
        fv_weight = conv_bn_layer(input=fv_weight, num_filters=128, filter_size=1, stride=1, name='fv_weight')
        #short cut
        fv_sc = conv_bn_layer(input=f_common, num_filters=128, filter_size=1, stride=1, name='fv_sc')
        f_v = fluid.layers.relu(fv_weight + fv_sc)
        ######
        f_attn = fluid.layers.concat([f_h, f_v], axis=1)
        f_attn = conv_bn_layer(input=f_attn, num_filters=128, filter_size=1, stride=1, act='relu', name='f_attn')
        return f_attn
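        # Shape walk-through (added note, not in the original file), for an
        # f_common of shape (N, 128, H, W):
        #   horizontal branch: transpose to (N, H, W, 128) and flatten rows
        #   to (N*H, W, 128); theta @ phi^T gives a (N*H, W, W) row-wise
        #   attention map, scaled by 1/sqrt(128) and softmaxed; attn @ g
        #   gives (N*H, W, 128), reshaped and transposed back to
        #   (N, 128, H, W).
        #   vertical branch: identical after the initial [0, 1, 3, 2]
        #   transpose swaps H and W, so attention runs over columns.
        # Each pixel attends only to its own row (f_h) or column (f_v),
        # costing O(H*W*(H + W)) instead of O((H*W)**2) for full non-local
        # attention.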
    def __call__(self, blocks, with_cab=False):
        for k, v in blocks.items():
            print(k, v.shape)

        #down fpn
        f_down = self.FPN_Down_Fusion(blocks)
        print("f_down shape: {}".format(f_down.shape))
        #up fpn
        f_up = self.FPN_Up_Fusion(blocks)
        print("f_up shape: {}".format(f_up.shape))
        #fusion
        f_common = fluid.layers.elementwise_add(x=f_down, y=f_up)
        f_common = fluid.layers.relu(f_common)
        print("f_common: {}".format(f_common.shape))

        if self.with_cab:
            print('enhance f_common with CAB.')
            f_common = self.cross_attention(f_common)

        f_score, f_border = self.SAST_Header1(f_common)
        f_tvo, f_tco = self.SAST_Header2(f_common)

        predicts = OrderedDict()
        predicts['f_score'] = f_score
        predicts['f_border'] = f_border
        predicts['f_tvo'] = f_tvo
        predicts['f_tco'] = f_tco
        return predicts
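Taken together, the head maps the fused 128-channel, 1/4-scale feature map to the four SAST output maps. A shape summary for an N x 3 x H x W input (an added note, with channel counts read off the conv_bn_layer calls above):

    f_score : (N, 1, H/4, W/4)   # text-center-line probability, after sigmoid
    f_border: (N, 4, H/4, W/4)   # two (x, y) offsets to the upper/lower border points
    f_tvo   : (N, 8, H/4, W/4)   # four (x, y) offsets to the quad vertices
    f_tco   : (N, 2, H/4, W/4)   # one (x, y) offset to the text center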
@@ -0,0 +1,115 @@
#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle.fluid as fluid


class SASTLoss(object):
    """
    SAST Loss function
    """

    def __init__(self, params=None):
        super(SASTLoss, self).__init__()

    def __call__(self, predicts, labels):
        """
        Compute the SAST training loss from the four prediction maps:
        a Dice-style loss on f_score and masked smooth-L1 losses on
        f_border, f_tvo and f_tco.
        """

        f_score = predicts['f_score']
        f_border = predicts['f_border']
        f_tvo = predicts['f_tvo']
        f_tco = predicts['f_tco']

        l_score = labels['input_score']
        l_border = labels['input_border']
        l_mask = labels['input_mask']
        l_tvo = labels['input_tvo']
        l_tco = labels['input_tco']

        #score_loss
        intersection = fluid.layers.reduce_sum(f_score * l_score * l_mask)
        union = fluid.layers.reduce_sum(f_score * l_mask) + fluid.layers.reduce_sum(l_score * l_mask)
        score_loss = 1.0 - 2 * intersection / (union + 1e-5)

        #border loss
        l_border_split, l_border_norm = fluid.layers.split(l_border, num_or_sections=[4, 1], dim=1)
        f_border_split = f_border
        l_border_norm_split = fluid.layers.expand(x=l_border_norm, expand_times=[1, 4, 1, 1])
        l_border_score = fluid.layers.expand(x=l_score, expand_times=[1, 4, 1, 1])
        l_border_mask = fluid.layers.expand(x=l_mask, expand_times=[1, 4, 1, 1])
        border_diff = l_border_split - f_border_split
        abs_border_diff = fluid.layers.abs(border_diff)
        border_sign = abs_border_diff < 1.0
        border_sign = fluid.layers.cast(border_sign, dtype='float32')
        border_sign.stop_gradient = True
        border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
                         (abs_border_diff - 0.5) * (1.0 - border_sign)
        border_out_loss = l_border_norm_split * border_in_loss
        border_loss = fluid.layers.reduce_sum(border_out_loss * l_border_score * l_border_mask) / \
                      (fluid.layers.reduce_sum(l_border_score * l_border_mask) + 1e-5)

        #tvo_loss
        l_tvo_split, l_tvo_norm = fluid.layers.split(l_tvo, num_or_sections=[8, 1], dim=1)
        f_tvo_split = f_tvo
        l_tvo_norm_split = fluid.layers.expand(x=l_tvo_norm, expand_times=[1, 8, 1, 1])
        l_tvo_score = fluid.layers.expand(x=l_score, expand_times=[1, 8, 1, 1])
        l_tvo_mask = fluid.layers.expand(x=l_mask, expand_times=[1, 8, 1, 1])
        tvo_geo_diff = l_tvo_split - f_tvo_split
        abs_tvo_geo_diff = fluid.layers.abs(tvo_geo_diff)
        tvo_sign = abs_tvo_geo_diff < 1.0
        tvo_sign = fluid.layers.cast(tvo_sign, dtype='float32')
        tvo_sign.stop_gradient = True
        tvo_in_loss = 0.5 * abs_tvo_geo_diff * abs_tvo_geo_diff * tvo_sign + \
                      (abs_tvo_geo_diff - 0.5) * (1.0 - tvo_sign)
        tvo_out_loss = l_tvo_norm_split * tvo_in_loss
        tvo_loss = fluid.layers.reduce_sum(tvo_out_loss * l_tvo_score * l_tvo_mask) / \
                   (fluid.layers.reduce_sum(l_tvo_score * l_tvo_mask) + 1e-5)

        #tco_loss
        l_tco_split, l_tco_norm = fluid.layers.split(l_tco, num_or_sections=[2, 1], dim=1)
        f_tco_split = f_tco
        l_tco_norm_split = fluid.layers.expand(x=l_tco_norm, expand_times=[1, 2, 1, 1])
        l_tco_score = fluid.layers.expand(x=l_score, expand_times=[1, 2, 1, 1])
        l_tco_mask = fluid.layers.expand(x=l_mask, expand_times=[1, 2, 1, 1])
        tco_geo_diff = l_tco_split - f_tco_split
        abs_tco_geo_diff = fluid.layers.abs(tco_geo_diff)
        tco_sign = abs_tco_geo_diff < 1.0
        tco_sign = fluid.layers.cast(tco_sign, dtype='float32')
        tco_sign.stop_gradient = True
        tco_in_loss = 0.5 * abs_tco_geo_diff * abs_tco_geo_diff * tco_sign + \
                      (abs_tco_geo_diff - 0.5) * (1.0 - tco_sign)
        tco_out_loss = l_tco_norm_split * tco_in_loss
        tco_loss = fluid.layers.reduce_sum(tco_out_loss * l_tco_score * l_tco_mask) / \
                   (fluid.layers.reduce_sum(l_tco_score * l_tco_mask) + 1e-5)

        # total loss
        tvo_lw, tco_lw = 1.5, 1.5
        score_lw, border_lw = 1.0, 1.0
        total_loss = score_loss * score_lw + border_loss * border_lw + \
                     tvo_loss * tvo_lw + tco_loss * tco_lw

        losses = {'total_loss': total_loss, "score_loss": score_loss,
                  "border_loss": border_loss, 'tvo_loss': tvo_loss, 'tco_loss': tco_loss}
        return losses
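The score term above is a Dice-style overlap loss, and the border/tvo/tco terms share one masked smooth-L1 pattern built from the stop-gradient sign mask. A minimal NumPy sketch of both pieces (illustrative only, with plain arrays standing in for fluid tensors):

    import numpy as np

    def dice_loss(pred, gt, mask, eps=1e-5):
        # 1 - 2*|X n Y| / (|X| + |Y|), restricted to valid pixels
        inter = np.sum(pred * gt * mask)
        union = np.sum(pred * mask) + np.sum(gt * mask)
        return 1.0 - 2.0 * inter / (union + eps)

    def smooth_l1(diff):
        # the sign-mask trick: quadratic inside |d| < 1, linear outside
        sign = (np.abs(diff) < 1.0).astype('float32')
        return 0.5 * diff * diff * sign + (np.abs(diff) - 0.5) * (1.0 - sign)

    assert np.isclose(smooth_l1(np.float32(0.5)), 0.125)  # 0.5 * 0.5 ** 2
    assert np.isclose(smooth_l1(np.float32(2.0)), 1.5)    # 2.0 - 0.5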
@@ -36,17 +36,28 @@ def AdamDecay(params, parameter_list=None):
    l2_decay = params.get("l2_decay", 0.0)

    if 'decay' in params:
        supported_decay_mode = ["cosine_decay", "piecewise_decay"]
        params = params['decay']
        decay_mode = params['function']
        step_each_epoch = params['step_each_epoch']
        total_epoch = params['total_epoch']
        assert decay_mode in supported_decay_mode, "Supported decay mode is {}, but got {}".format(
            supported_decay_mode, decay_mode)

        if decay_mode == "cosine_decay":
            step_each_epoch = params['step_each_epoch']
            total_epoch = params['total_epoch']
            base_lr = fluid.layers.cosine_decay(
                learning_rate=base_lr,
                step_each_epoch=step_each_epoch,
                epochs=total_epoch)
        else:
            logger.info("Only support Cosine decay currently")
        elif decay_mode == "piecewise_decay":
            boundaries = params["boundaries"]
            decay_rate = params["decay_rate"]
            values = [
                base_lr * decay_rate**idx
                for idx in range(len(boundaries) + 1)
            ]
            base_lr = fluid.layers.piecewise_decay(boundaries, values)

    optimizer = fluid.optimizer.Adam(
        learning_rate=base_lr,
        beta1=beta1,
@@ -54,3 +65,44 @@ def AdamDecay(params, parameter_list=None):
        regularization=L2Decay(regularization_coeff=l2_decay),
        parameter_list=parameter_list)
    return optimizer


def RMSProp(params, parameter_list=None):
    """
    define optimizer function
    args:
        params(dict): the hyper-parameters
        parameter_list (list): list of Variable names to update to minimize loss
    return:
        a fluid.optimizer.RMSProp instance
    """
    base_lr = params.get("base_lr", 0.001)
    l2_decay = params.get("l2_decay", 0.00005)

    if 'decay' in params:
        supported_decay_mode = ["cosine_decay", "piecewise_decay"]
        params = params['decay']
        decay_mode = params['function']
        assert decay_mode in supported_decay_mode, "Supported decay mode is {}, but got {}".format(
            supported_decay_mode, decay_mode)

        if decay_mode == "cosine_decay":
            step_each_epoch = params['step_each_epoch']
            total_epoch = params['total_epoch']
            base_lr = fluid.layers.cosine_decay(
                learning_rate=base_lr,
                step_each_epoch=step_each_epoch,
                epochs=total_epoch)
        elif decay_mode == "piecewise_decay":
            boundaries = params["boundaries"]
            decay_rate = params["decay_rate"]
            values = [
                base_lr * decay_rate**idx
                for idx in range(len(boundaries) + 1)
            ]
            base_lr = fluid.layers.piecewise_decay(boundaries, values)

    optimizer = fluid.optimizer.RMSProp(
        learning_rate=base_lr,
        regularization=fluid.regularizer.L2Decay(regularization_coeff=l2_decay))

    return optimizer
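For piecewise_decay the schedule is simply base_lr scaled by decay_rate once per boundary; a quick illustration with made-up numbers (not taken from any config):

    base_lr, decay_rate = 0.001, 0.5   # example values only
    boundaries = [10000, 20000]        # global steps at which the lr drops
    values = [base_lr * decay_rate ** idx for idx in range(len(boundaries) + 1)]
    # values == [0.001, 0.0005, 0.00025]: the lr before 10k steps,
    # between 10k and 20k, and after 20k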
@@ -0,0 +1,289 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))

import numpy as np
from .locality_aware_nms import nms_locality
# import lanms
import cv2
import time


class SASTPostProcess(object):
    """
    The post process for SAST.
    """

    def __init__(self, params):
        self.score_thresh = params.get('score_thresh', 0.5)
        self.nms_thresh = params.get('nms_thresh', 0.2)
        self.sample_pts_num = params.get('sample_pts_num', 2)
        self.shrink_ratio_of_width = params.get('shrink_ratio_of_width', 0.3)
        self.expand_scale = params.get('expand_scale', 1.0)
        self.tcl_map_thresh = 0.5

        # c++ la-nms is faster, but only supports Python 3.5
        self.is_python35 = False
        if sys.version_info.major == 3 and sys.version_info.minor == 5:
            self.is_python35 = True

    def point_pair2poly(self, point_pair_list):
        """
        Transfer vertical point pairs into a clockwise polygon.
        """
        # construct poly
        point_num = len(point_pair_list) * 2
        point_list = [0] * point_num
        for idx, point_pair in enumerate(point_pair_list):
            point_list[idx] = point_pair[0]
            point_list[point_num - 1 - idx] = point_pair[1]
        return np.array(point_list).reshape(-1, 2)

    def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
        """
        Generate shrink_quad_along_width.
        """
        ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
        p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
        p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
        return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])

    def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
        """
        Expand poly along width.
        """
        point_num = poly.shape[0]
        left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
        left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
                     (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
        left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
        right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1],
                               poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32)
        right_ratio = 1.0 + \
                      shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
                      (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
        right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
        poly[0] = left_quad_expand[0]
        poly[-1] = left_quad_expand[-1]
        poly[point_num // 2 - 1] = right_quad_expand[1]
        poly[point_num // 2] = right_quad_expand[2]
        return poly

    def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
        """Restore quad."""
        xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
        xy_text = xy_text[:, ::-1]  # (n, 2)

        # Sort the text boxes via the y axis
        xy_text = xy_text[np.argsort(xy_text[:, 1])]

        scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
        scores = scores[:, np.newaxis]

        # Restore
        point_num = int(tvo_map.shape[-1] / 2)
        assert point_num == 4
        tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
        xy_text_tile = np.tile(xy_text, (1, point_num))  # (n, point_num * 2)
        quads = xy_text_tile - tvo_map

        return scores, quads, xy_text

    def quad_area(self, quad):
        """
        Compute the signed area of a quad.
        """
        edge = [
            (quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
            (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
            (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
            (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])
        ]
        return np.sum(edge) / 2.
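        # Worked example (added note, not in the original file): this is the
        # shoelace formula, and for a quad ordered top-left -> top-right ->
        # bottom-right -> bottom-left in image coordinates the result is
        # negative, which is why detect_sast stores -self.quad_area(quad):
        #   >>> quad = np.array([[0, 0], [1, 0], [1, 1], [0, 1]], 'float32')
        #   >>> SASTPostProcess({}).quad_area(quad)
        #   -1.0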
    def nms(self, dets):
        if self.is_python35:
            import lanms
            dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh)
        else:
            dets = nms_locality(dets, self.nms_thresh)
        return dets

    def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map):
        """
        Cluster pixels in tcl_map based on quads.
        """
        instance_count = quads.shape[0] + 1  # contains background
        instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
        if instance_count == 1:
            return instance_count, instance_label_map

        # predicted text centers
        xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
        n = xy_text.shape[0]
        xy_text = xy_text[:, ::-1]  # (n, 2)
        tco = tco_map[xy_text[:, 1], xy_text[:, 0], :]  # (n, 2)
        pred_tc = xy_text - tco

        # gt text centers
        m = quads.shape[0]
        gt_tc = np.mean(quads, axis=1)  # (m, 2)

        pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1))  # (n, m, 2)
        gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1))  # (n, m, 2)
        dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2)  # (n, m)
        xy_text_assign = np.argmin(dist_mat, axis=1) + 1  # (n,)

        instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
        return instance_count, instance_label_map
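        # Added note (not in the original file): this is a plain
        # nearest-center assignment. Each TCL pixel, shifted by its predicted
        # center offset, takes the label of the closest quad center:
        #   >>> pred_tc = np.array([[10., 10.], [52., 48.]])
        #   >>> gt_tc = np.array([[12., 11.], [50., 50.]])
        #   >>> dist = np.linalg.norm(pred_tc[:, None] - gt_tc[None], axis=2)
        #   >>> np.argmin(dist, axis=1) + 1  # label 0 is the background
        #   array([1, 2])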
    def estimate_sample_pts_num(self, quad, xy_text):
        """
        Estimate the number of sample points.
        """
        eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0
        ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0

        dense_sample_pts_num = max(2, int(ew))
        dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num,
                                                   endpoint=True, dtype=np.float32).astype(np.int32)]

        dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
        estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))

        sample_pts_num = max(2, int(estimate_arc_len / eh))
        return sample_pts_num

    def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h,
                    shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_stride=4.0):
        """
        First resize the tcl_map, tvo_map and tbo_map to the input size, then restore the polys.
        """
        # restore quad
        scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
        dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
        dets = self.nms(dets)
        if dets.shape[0] == 0:
            return []
        quads = dets[:, :-1].reshape(-1, 4, 2)

        # Compute quad area
        quad_areas = []
        for quad in quads:
            quad_areas.append(-self.quad_area(quad))

        # instance segmentation
        # instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
        instance_count, instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map)

        # restore single poly with tcl instance.
        poly_list = []
        for instance_idx in range(1, instance_count):
            xy_text = np.argwhere(instance_label_map == instance_idx)[:, ::-1]
            quad = quads[instance_idx - 1]
            q_area = quad_areas[instance_idx - 1]
            if q_area < 5:
                continue

            # filter quads with a too-short side
            len1 = float(np.linalg.norm(quad[0] - quad[1]))
            len2 = float(np.linalg.norm(quad[1] - quad[2]))
            min_len = min(len1, len2)
            if min_len < 3:
                continue

            # filter small CC
            if xy_text.shape[0] <= 0:
                continue

            # filter low confidence instance
            xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
            if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
                # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
                continue

            # sort xy_text
            left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0,
                                        (quad[0, 1] + quad[-1, 1]) / 2.0]])  # (1, 2)
            right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0,
                                         (quad[1, 1] + quad[2, 1]) / 2.0]])  # (1, 2)
            proj_unit_vec = (right_center_pt - left_center_pt) / \
                            (np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
            proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
            xy_text = xy_text[np.argsort(proj_value)]

            # Sample pts in tcl map
            if self.sample_pts_num == 0:
                sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
            else:
                sample_pts_num = self.sample_pts_num
            xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num,
                                                 endpoint=True, dtype=np.float32).astype(np.int32)]

            point_pair_list = []
            for x, y in xy_center_line:
                # get corresponding offset
                offset = tbo_map[y, x, :].reshape(2, 2)
                if offset_expand != 1.0:
                    offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
                    expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0)
                    offset_delta = offset / offset_length * expand_length
                    offset = offset + offset_delta
                # original point
                ori_yx = np.array([y, x], dtype=np.float32)
                point_pair = (ori_yx + offset)[:, ::-1] * out_stride / np.array([ratio_w, ratio_h]).reshape(-1, 2)
                point_pair_list.append(point_pair)

            # ndarray: (x, 2), expand poly along width
            detected_poly = self.point_pair2poly(point_pair_list)
            detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width)
            detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
            detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
            poly_list.append(detected_poly)

        return poly_list

    def __call__(self, outs_dict, ratio_list):
        score_list = outs_dict['f_score']
        border_list = outs_dict['f_border']
        tvo_list = outs_dict['f_tvo']
        tco_list = outs_dict['f_tco']

        img_num = len(ratio_list)
        poly_lists = []
        for ino in range(img_num):
            p_score = score_list[ino].transpose((1, 2, 0))
            p_border = border_list[ino].transpose((1, 2, 0))
            p_tvo = tvo_list[ino].transpose((1, 2, 0))
            p_tco = tco_list[ino].transpose((1, 2, 0))
            # print(p_score.shape, p_border.shape, p_tvo.shape, p_tco.shape)
            ratio_h, ratio_w, src_h, src_w = ratio_list[ino]

            poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h,
                                         shrink_ratio_of_width=self.shrink_ratio_of_width,
                                         tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale)

            poly_lists.append(poly_list)

        return poly_lists
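Since the prediction maps live at 1/4 of the network input, and the input itself was resized from the source image, detect_sast maps coordinates back with "* out_stride / ratio". A small illustration with made-up numbers:

    out_stride = 4.0
    ratio_w, ratio_h = 0.5, 0.5   # preprocessing resize ratios (example values)
    x_map, y_map = 30, 20         # a point in the 1/4-scale prediction maps
    x_src = x_map * out_stride / ratio_w   # 240.0 in source-image pixels
    y_src = y_map * out_stride / ratio_h   # 160.0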
@@ -21,7 +21,6 @@ import os
import shutil
import tempfile

import paddle
import paddle.fluid as fluid

from .utility import initial_logger
@@ -110,17 +109,20 @@ def init_model(config, program, exe):
"""
|
||||
checkpoints = config['Global'].get('checkpoints')
|
||||
if checkpoints:
|
||||
path = checkpoints
|
||||
fluid.load(program, path, exe)
|
||||
logger.info("Finish initing model from {}".format(path))
|
||||
return
|
||||
|
||||
pretrain_weights = config['Global'].get('pretrain_weights')
|
||||
if pretrain_weights:
|
||||
path = pretrain_weights
|
||||
load_params(exe, program, path)
|
||||
logger.info("Finish initing model from {}".format(path))
|
||||
return
|
||||
if os.path.exists(checkpoints + '.pdparams'):
|
||||
path = checkpoints
|
||||
fluid.load(program, path, exe)
|
||||
logger.info("Finish initing model from {}".format(path))
|
||||
else:
|
||||
raise ValueError("Model checkpoints {} does not exists,"
|
||||
"check if you lost the file prefix.".format(
|
||||
checkpoints + '.pdparams'))
|
||||
else:
|
||||
pretrain_weights = config['Global'].get('pretrain_weights')
|
||||
if pretrain_weights:
|
||||
path = pretrain_weights
|
||||
load_params(exe, program, path)
|
||||
logger.info("Finish initing model from {}".format(path))
|
||||
|
||||
|
||||
def save_model(program, model_path):
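The rewritten branch guards against a missing file prefix: with the fluid save/load API, saving under a prefix produces <prefix>.pdparams (and <prefix>.pdopt for optimizer state), and fluid.load expects the same prefix back, so checkpoints must name the prefix rather than the .pdparams file itself. A hypothetical example of the convention being checked:

    # fluid.save(program, './output/det_sast/best_accuracy') would write
    #   ./output/det_sast/best_accuracy.pdparams
    #   ./output/det_sast/best_accuracy.pdopt
    # so config['Global']['checkpoints'] should be './output/det_sast/best_accuracy'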
@@ -58,7 +58,7 @@ def main():
    program.check_gpu(use_gpu)

    alg = config['Global']['algorithm']
    assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE']
    assert alg in ['EAST', 'DB', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SAST']
    if alg in ['Rosetta', 'CRNN', 'STARNet', 'RARE']:
        config['Global']['char_ops'] = CharacterOps(config['Global'])
@@ -75,7 +75,7 @@ def main():

    init_model(config, eval_program, exe)

    if alg in ['EAST', 'DB']:
    if alg in ['EAST', 'DB', 'SAST']:
        eval_reader = reader_main(config=config, mode="eval")
        eval_info_dict = {'program':eval_program,\
            'reader':eval_reader,\
@@ -88,8 +88,8 @@ class DetectionIoUEvaluator(object):
                points = gt[n]['points']
                # transcription = gt[n]['text']
                dontCare = gt[n]['ignore']
                points = Polygon(points)
                points = points.buffer(0)
                # points = Polygon(points)
                # points = points.buffer(0)
                if not Polygon(points).is_valid or not Polygon(points).is_simple:
                    continue

@@ -105,8 +105,8 @@ class DetectionIoUEvaluator(object):

        for n in range(len(pred)):
            points = pred[n]['points']
            points = Polygon(points)
            points = points.buffer(0)
            # points = Polygon(points)
            # points = points.buffer(0)
            if not Polygon(points).is_valid or not Polygon(points).is_simple:
                continue
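Both hunks replace the buffer(0) repair with an up-front validity check, so degenerate annotations are now skipped rather than silently rewritten. A quick shapely illustration of what the new check rejects:

    from shapely.geometry import Polygon

    bowtie = Polygon([(0, 0), (2, 2), (2, 0), (0, 2)])  # self-intersecting
    square = Polygon([(0, 0), (2, 0), (2, 2), (0, 2)])
    print(bowtie.is_valid, square.is_valid)  # False True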