fix sast process

This commit is contained in:
LDOUBLEV 2021-01-21 11:11:03 +08:00
parent 18669cc344
commit 16bd2dd093
1 changed files with 188 additions and 103 deletions

View File

@ -24,11 +24,11 @@ __all__ = ['SASTProcessTrain']
class SASTProcessTrain(object): class SASTProcessTrain(object):
def __init__(self, def __init__(self,
image_shape = [512, 512], image_shape=[512, 512],
min_crop_size = 24, min_crop_size=24,
min_crop_side_ratio = 0.3, min_crop_side_ratio=0.3,
min_text_size = 10, min_text_size=10,
max_text_size = 512, max_text_size=512,
**kwargs): **kwargs):
self.input_size = image_shape[1] self.input_size = image_shape[1]
self.min_crop_size = min_crop_size self.min_crop_size = min_crop_size
@ -42,12 +42,10 @@ class SASTProcessTrain(object):
:param poly: :param poly:
:return: :return:
""" """
edge = [ edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]), (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]), (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]), (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])
]
return np.sum(edge) / 2. return np.sum(edge) / 2.
def gen_quad_from_poly(self, poly): def gen_quad_from_poly(self, poly):
@ -57,7 +55,8 @@ class SASTProcessTrain(object):
point_num = poly.shape[0] point_num = poly.shape[0]
min_area_quad = np.zeros((4, 2), dtype=np.float32) min_area_quad = np.zeros((4, 2), dtype=np.float32)
if True: if True:
rect = cv2.minAreaRect(poly.astype(np.int32)) # (center (x,y), (width, height), angle of rotation) rect = cv2.minAreaRect(poly.astype(
np.int32)) # (center (x,y), (width, height), angle of rotation)
center_point = rect[0] center_point = rect[0]
box = np.array(cv2.boxPoints(rect)) box = np.array(cv2.boxPoints(rect))
@ -102,23 +101,33 @@ class SASTProcessTrain(object):
if p_area > 0: if p_area > 0:
if tag == False: if tag == False:
print('poly in wrong direction') print('poly in wrong direction')
tag = True # reversed cases should be ignore tag = True # reversed cases should be ignore
poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1), :] poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
1), :]
quad = quad[(0, 3, 2, 1), :] quad = quad[(0, 3, 2, 1), :]
len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] - quad[2]) len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2]) quad[2])
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
quad[2])
hv_tag = 1 hv_tag = 1
if len_w * 2.0 < len_h: if len_w * 2.0 < len_h:
hv_tag = 0 hv_tag = 0
validated_polys.append(poly) validated_polys.append(poly)
validated_tags.append(tag) validated_tags.append(tag)
hv_tags.append(hv_tag) hv_tags.append(hv_tag)
return np.array(validated_polys), np.array(validated_tags), np.array(hv_tags) return np.array(validated_polys), np.array(validated_tags), np.array(
hv_tags)
def crop_area(self, im, polys, tags, hv_tags, crop_background=False, max_tries=25): def crop_area(self,
im,
polys,
tags,
hv_tags,
crop_background=False,
max_tries=25):
""" """
make random crop from the input image make random crop from the input image
:param im: :param im:
@ -137,10 +146,10 @@ class SASTProcessTrain(object):
poly = np.round(poly, decimals=0).astype(np.int32) poly = np.round(poly, decimals=0).astype(np.int32)
minx = np.min(poly[:, 0]) minx = np.min(poly[:, 0])
maxx = np.max(poly[:, 0]) maxx = np.max(poly[:, 0])
w_array[minx + pad_w: maxx + pad_w] = 1 w_array[minx + pad_w:maxx + pad_w] = 1
miny = np.min(poly[:, 1]) miny = np.min(poly[:, 1])
maxy = np.max(poly[:, 1]) maxy = np.max(poly[:, 1])
h_array[miny + pad_h: maxy + pad_h] = 1 h_array[miny + pad_h:maxy + pad_h] = 1
# ensure the cropped area not across a text # ensure the cropped area not across a text
h_axis = np.where(h_array == 0)[0] h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0] w_axis = np.where(w_array == 0)[0]
@ -166,17 +175,18 @@ class SASTProcessTrain(object):
if polys.shape[0] != 0: if polys.shape[0] != 0:
poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \ poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
& (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax) & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0] selected_polys = np.where(
np.sum(poly_axis_in_area, axis=1) == 4)[0]
else: else:
selected_polys = [] selected_polys = []
if len(selected_polys) == 0: if len(selected_polys) == 0:
# no text in this area # no text in this area
if crop_background: if crop_background:
return im[ymin : ymax + 1, xmin : xmax + 1, :], \ return im[ymin : ymax + 1, xmin : xmax + 1, :], \
polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts polys[selected_polys], tags[selected_polys], hv_tags[selected_polys]
else: else:
continue continue
im = im[ymin: ymax + 1, xmin: xmax + 1, :] im = im[ymin:ymax + 1, xmin:xmax + 1, :]
polys = polys[selected_polys] polys = polys[selected_polys]
tags = tags[selected_polys] tags = tags[selected_polys]
hv_tags = hv_tags[selected_polys] hv_tags = hv_tags[selected_polys]
@ -192,18 +202,28 @@ class SASTProcessTrain(object):
width_list = [] width_list = []
height_list = [] height_list = []
for quad in poly_quads: for quad in poly_quads:
quad_w = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0 quad_w = (np.linalg.norm(quad[0] - quad[1]) +
quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0 np.linalg.norm(quad[2] - quad[3])) / 2.0
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[2] - quad[1])) / 2.0
width_list.append(quad_w) width_list.append(quad_w)
height_list.append(quad_h) height_list.append(quad_h)
norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0) norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0)
average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0) average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0)
for quad in poly_quads: for quad in poly_quads:
direct_vector_full = ((quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0 direct_vector_full = (
direct_vector = direct_vector_full / (np.linalg.norm(direct_vector_full) + 1e-6) * norm_width (quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
direction_label = tuple(map(float, [direct_vector[0], direct_vector[1], 1.0 / (average_height + 1e-6)])) direct_vector = direct_vector_full / (
cv2.fillPoly(direction_map, quad.round().astype(np.int32)[np.newaxis, :, :], direction_label) np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
direction_label = tuple(
map(float, [
direct_vector[0], direct_vector[1], 1.0 / (average_height +
1e-6)
]))
cv2.fillPoly(direction_map,
quad.round().astype(np.int32)[np.newaxis, :, :],
direction_label)
return direction_map return direction_map
def calculate_average_height(self, poly_quads): def calculate_average_height(self, poly_quads):
@ -211,13 +231,19 @@ class SASTProcessTrain(object):
""" """
height_list = [] height_list = []
for quad in poly_quads: for quad in poly_quads:
quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0 quad_h = (np.linalg.norm(quad[0] - quad[3]) +
np.linalg.norm(quad[2] - quad[1])) / 2.0
height_list.append(quad_h) height_list.append(quad_h)
average_height = max(sum(height_list) / len(height_list), 1.0) average_height = max(sum(height_list) / len(height_list), 1.0)
return average_height return average_height
def generate_tcl_label(self, hw, polys, tags, ds_ratio, def generate_tcl_label(self,
tcl_ratio=0.3, shrink_ratio_of_width=0.15): hw,
polys,
tags,
ds_ratio,
tcl_ratio=0.3,
shrink_ratio_of_width=0.15):
""" """
Generate polygon. Generate polygon.
""" """
@ -225,21 +251,30 @@ class SASTProcessTrain(object):
h, w = int(h * ds_ratio), int(w * ds_ratio) h, w = int(h * ds_ratio), int(w * ds_ratio)
polys = polys * ds_ratio polys = polys * ds_ratio
score_map = np.zeros((h, w,), dtype=np.float32) score_map = np.zeros(
(
h,
w, ), dtype=np.float32)
tbo_map = np.zeros((h, w, 5), dtype=np.float32) tbo_map = np.zeros((h, w, 5), dtype=np.float32)
training_mask = np.ones((h, w,), dtype=np.float32) training_mask = np.ones(
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape([1, 1, 3]).astype(np.float32) (
h,
w, ), dtype=np.float32)
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
[1, 1, 3]).astype(np.float32)
for poly_idx, poly_tag in enumerate(zip(polys, tags)): for poly_idx, poly_tag in enumerate(zip(polys, tags)):
poly = poly_tag[0] poly = poly_tag[0]
tag = poly_tag[1] tag = poly_tag[1]
# generate min_area_quad # generate min_area_quad
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly) min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) + min_area_quad_h = 0.5 * (
np.linalg.norm(min_area_quad[1] - min_area_quad[2])) np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) + np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
np.linalg.norm(min_area_quad[2] - min_area_quad[3])) min_area_quad_w = 0.5 * (
np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \ if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio: or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
@ -247,25 +282,37 @@ class SASTProcessTrain(object):
if tag: if tag:
# continue # continue
cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0.15) cv2.fillPoly(training_mask,
poly.astype(np.int32)[np.newaxis, :, :], 0.15)
else: else:
tcl_poly = self.poly2tcl(poly, tcl_ratio) tcl_poly = self.poly2tcl(poly, tcl_ratio)
tcl_quads = self.poly2quads(tcl_poly) tcl_quads = self.poly2quads(tcl_poly)
poly_quads = self.poly2quads(poly) poly_quads = self.poly2quads(poly)
# stcl map # stcl map
stcl_quads, quad_index = self.shrink_poly_along_width(tcl_quads, shrink_ratio_of_width=shrink_ratio_of_width, stcl_quads, quad_index = self.shrink_poly_along_width(
expand_height_ratio=1.0 / tcl_ratio) tcl_quads,
shrink_ratio_of_width=shrink_ratio_of_width,
expand_height_ratio=1.0 / tcl_ratio)
# generate tcl map # generate tcl map
cv2.fillPoly(score_map, np.round(stcl_quads).astype(np.int32), 1.0) cv2.fillPoly(score_map,
np.round(stcl_quads).astype(np.int32), 1.0)
# generate tbo map # generate tbo map
for idx, quad in enumerate(stcl_quads): for idx, quad in enumerate(stcl_quads):
quad_mask = np.zeros((h, w), dtype=np.float32) quad_mask = np.zeros((h, w), dtype=np.float32)
quad_mask = cv2.fillPoly(quad_mask, np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0) quad_mask = cv2.fillPoly(
tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]], quad_mask, tbo_map) quad_mask,
np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
quad_mask, tbo_map)
return score_map, tbo_map, training_mask return score_map, tbo_map, training_mask
def generate_tvo_and_tco(self, hw, polys, tags, tcl_ratio=0.3, ds_ratio=0.25): def generate_tvo_and_tco(self,
hw,
polys,
tags,
tcl_ratio=0.3,
ds_ratio=0.25):
""" """
Generate tcl map, tvo map and tbo map. Generate tcl map, tvo map and tbo map.
""" """
@ -297,35 +344,44 @@ class SASTProcessTrain(object):
# generate min_area_quad # generate min_area_quad
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly) min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) + min_area_quad_h = 0.5 * (
np.linalg.norm(min_area_quad[1] - min_area_quad[2])) np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) + np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
np.linalg.norm(min_area_quad[2] - min_area_quad[3])) min_area_quad_w = 0.5 * (
np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
# generate tcl map and text, 128 * 128 # generate tcl map and text, 128 * 128
tcl_poly = self.poly2tcl(poly, tcl_ratio) tcl_poly = self.poly2tcl(poly, tcl_ratio)
# generate poly_tv_xy_map # generate poly_tv_xy_map
for idx in range(4): for idx in range(4):
cv2.fillPoly(poly_tv_xy_map[2 * idx], cv2.fillPoly(
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), poly_tv_xy_map[2 * idx],
float(min(max(min_area_quad[idx, 0], 0), w))) np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
cv2.fillPoly(poly_tv_xy_map[2 * idx + 1], float(min(max(min_area_quad[idx, 0], 0), w)))
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), cv2.fillPoly(
float(min(max(min_area_quad[idx, 1], 0), h))) poly_tv_xy_map[2 * idx + 1],
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
float(min(max(min_area_quad[idx, 1], 0), h)))
# generate poly_tc_xy_map # generate poly_tc_xy_map
for idx in range(2): for idx in range(2):
cv2.fillPoly(poly_tc_xy_map[idx], cv2.fillPoly(
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), float(center_point[idx])) poly_tc_xy_map[idx],
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
float(center_point[idx]))
# generate poly_short_edge_map # generate poly_short_edge_map
cv2.fillPoly(poly_short_edge_map, cv2.fillPoly(
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), poly_short_edge_map,
float(max(min(min_area_quad_h, min_area_quad_w), 1.0))) np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
float(max(min(min_area_quad_h, min_area_quad_w), 1.0)))
# generate poly_mask and training_mask # generate poly_mask and training_mask
cv2.fillPoly(poly_mask, np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), 1) cv2.fillPoly(poly_mask,
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
1)
tvo_map *= poly_mask tvo_map *= poly_mask
tvo_map[:8] -= poly_tv_xy_map tvo_map[:8] -= poly_tv_xy_map
@ -356,7 +412,8 @@ class SASTProcessTrain(object):
elif point_num > 4: elif point_num > 4:
vector_1 = poly[0] - poly[1] vector_1 = poly[0] - poly[1]
vector_2 = poly[1] - poly[2] vector_2 = poly[1] - poly[2]
cos_theta = np.dot(vector_1, vector_2) / (np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6) cos_theta = np.dot(vector_1, vector_2) / (
np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
theta = np.arccos(np.round(cos_theta, decimals=4)) theta = np.arccos(np.round(cos_theta, decimals=4))
if abs(theta) > (70 / 180 * math.pi): if abs(theta) > (70 / 180 * math.pi):
@ -374,7 +431,8 @@ class SASTProcessTrain(object):
min_area_quad = poly min_area_quad = poly
center_point = np.sum(poly, axis=0) / 4 center_point = np.sum(poly, axis=0) / 4
else: else:
rect = cv2.minAreaRect(poly.astype(np.int32)) # (center (x,y), (width, height), angle of rotation) rect = cv2.minAreaRect(poly.astype(
np.int32)) # (center (x,y), (width, height), angle of rotation)
center_point = rect[0] center_point = rect[0]
box = np.array(cv2.boxPoints(rect)) box = np.array(cv2.boxPoints(rect))
@ -394,16 +452,23 @@ class SASTProcessTrain(object):
return min_area_quad, center_point return min_area_quad, center_point
def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.): def shrink_quad_along_width(self,
quad,
begin_width_ratio=0.,
end_width_ratio=1.):
""" """
Generate shrink_quad_along_width. Generate shrink_quad_along_width.
""" """
ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32) ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def shrink_poly_along_width(self, quads, shrink_ratio_of_width, expand_height_ratio=1.0): def shrink_poly_along_width(self,
quads,
shrink_ratio_of_width,
expand_height_ratio=1.0):
""" """
shrink poly with given length. shrink poly with given length.
""" """
@ -421,22 +486,28 @@ class SASTProcessTrain(object):
upper_edge_list.append(upper_edge_len) upper_edge_list.append(upper_edge_len)
# length of left edge and right edge. # length of left edge and right edge.
left_length = np.linalg.norm(quads[0][0] - quads[0][3]) * expand_height_ratio left_length = np.linalg.norm(quads[0][0] - quads[0][
right_length = np.linalg.norm(quads[-1][1] - quads[-1][2]) * expand_height_ratio 3]) * expand_height_ratio
right_length = np.linalg.norm(quads[-1][1] - quads[-1][
2]) * expand_height_ratio
shrink_length = min(left_length, right_length, sum(upper_edge_list)) * shrink_ratio_of_width shrink_length = min(left_length, right_length,
sum(upper_edge_list)) * shrink_ratio_of_width
# shrinking length # shrinking length
upper_len_left = shrink_length upper_len_left = shrink_length
upper_len_right = sum(upper_edge_list) - shrink_length upper_len_right = sum(upper_edge_list) - shrink_length
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left) left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
left_quad = self.shrink_quad_along_width(quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1) left_quad = self.shrink_quad_along_width(
quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right) right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
right_quad = self.shrink_quad_along_width(quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio) right_quad = self.shrink_quad_along_width(
quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
out_quad_list = [] out_quad_list = []
if left_idx == right_idx: if left_idx == right_idx:
out_quad_list.append([left_quad[0], right_quad[1], right_quad[2], left_quad[3]]) out_quad_list.append(
[left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
else: else:
out_quad_list.append(left_quad) out_quad_list.append(left_quad)
for idx in range(left_idx + 1, right_idx): for idx in range(left_idx + 1, right_idx):
@ -500,7 +571,8 @@ class SASTProcessTrain(object):
""" """
Generate center line by poly clock-wise point. (4, 2) Generate center line by poly clock-wise point. (4, 2)
""" """
ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32) ratio_pair = np.array(
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]]) return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
@ -509,12 +581,14 @@ class SASTProcessTrain(object):
""" """
Generate center line by poly clock-wise point. Generate center line by poly clock-wise point.
""" """
ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32) ratio_pair = np.array(
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
tcl_poly = np.zeros_like(poly) tcl_poly = np.zeros_like(poly)
point_num = poly.shape[0] point_num = poly.shape[0]
for idx in range(point_num // 2): for idx in range(point_num // 2):
point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]) * ratio_pair point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
) * ratio_pair
tcl_poly[idx] = point_pair[0] tcl_poly[idx] = point_pair[0]
tcl_poly[point_num - 1 - idx] = point_pair[1] tcl_poly[point_num - 1 - idx] = point_pair[1]
return tcl_poly return tcl_poly
@ -527,8 +601,10 @@ class SASTProcessTrain(object):
up_line = self.line_cross_two_point(quad[0], quad[1]) up_line = self.line_cross_two_point(quad[0], quad[1])
lower_line = self.line_cross_two_point(quad[3], quad[2]) lower_line = self.line_cross_two_point(quad[3], quad[2])
quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) np.linalg.norm(quad[1] - quad[2]))
quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
np.linalg.norm(quad[2] - quad[3]))
# average angle of left and right line. # average angle of left and right line.
angle = self.average_angle(quad) angle = self.average_angle(quad)
@ -565,7 +641,8 @@ class SASTProcessTrain(object):
quad_num = point_num // 2 - 1 quad_num = point_num // 2 - 1
for idx in range(quad_num): for idx in range(quad_num):
# reshape and adjust to clock-wise # reshape and adjust to clock-wise
quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]).reshape(4, 2)[[0, 2, 3, 1]]) quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
).reshape(4, 2)[[0, 2, 3, 1]])
return np.array(quad_list) return np.array(quad_list)
@ -579,7 +656,8 @@ class SASTProcessTrain(object):
return None return None
h, w, _ = im.shape h, w, _ = im.shape
text_polys, text_tags, hv_tags = self.check_and_validate_polys(text_polys, text_tags, (h, w)) text_polys, text_tags, hv_tags = self.check_and_validate_polys(
text_polys, text_tags, (h, w))
if text_polys.shape[0] == 0: if text_polys.shape[0] == 0:
return None return None
@ -591,7 +669,7 @@ class SASTProcessTrain(object):
if np.random.rand() < 0.5: if np.random.rand() < 0.5:
asp_scale = 1.0 / asp_scale asp_scale = 1.0 / asp_scale
asp_scale = math.sqrt(asp_scale) asp_scale = math.sqrt(asp_scale)
asp_wx = asp_scale asp_wx = asp_scale
asp_hy = 1.0 / asp_scale asp_hy = 1.0 / asp_scale
im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy) im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
@ -610,7 +688,7 @@ class SASTProcessTrain(object):
#no background #no background
im, text_polys, text_tags, hv_tags = self.crop_area(im, \ im, text_polys, text_tags, hv_tags = self.crop_area(im, \
text_polys, text_tags, hv_tags, crop_background=False) text_polys, text_tags, hv_tags, crop_background=False)
if text_polys.shape[0] == 0: if text_polys.shape[0] == 0:
return None return None
#continue for all ignore case #continue for all ignore case
@ -621,17 +699,18 @@ class SASTProcessTrain(object):
return None return None
#resize image #resize image
std_ratio = float(self.input_size) / max(new_w, new_h) std_ratio = float(self.input_size) / max(new_w, new_h)
rand_scales = np.array([0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0]) rand_scales = np.array(
[0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
rz_scale = std_ratio * np.random.choice(rand_scales) rz_scale = std_ratio * np.random.choice(rand_scales)
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale) im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
text_polys[:, :, 0] *= rz_scale text_polys[:, :, 0] *= rz_scale
text_polys[:, :, 1] *= rz_scale text_polys[:, :, 1] *= rz_scale
#add gaussian blur #add gaussian blur
if np.random.rand() < 0.1 * 0.5: if np.random.rand() < 0.1 * 0.5:
ks = np.random.permutation(5)[0] + 1 ks = np.random.permutation(5)[0] + 1
ks = int(ks/2)*2 + 1 ks = int(ks / 2) * 2 + 1
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0) im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
#add brighter #add brighter
if np.random.rand() < 0.1 * 0.5: if np.random.rand() < 0.1 * 0.5:
im = im * (1.0 + np.random.rand() * 0.5) im = im * (1.0 + np.random.rand() * 0.5)
@ -640,13 +719,14 @@ class SASTProcessTrain(object):
if np.random.rand() < 0.1 * 0.5: if np.random.rand() < 0.1 * 0.5:
im = im * (1.0 - np.random.rand() * 0.5) im = im * (1.0 - np.random.rand() * 0.5)
im = np.clip(im, 0.0, 255.0) im = np.clip(im, 0.0, 255.0)
# Padding the im to [input_size, input_size] # Padding the im to [input_size, input_size]
new_h, new_w, _ = im.shape new_h, new_w, _ = im.shape
if min(new_w, new_h) < self.input_size * 0.5: if min(new_w, new_h) < self.input_size * 0.5:
return None return None
im_padded = np.ones((self.input_size, self.input_size, 3), dtype=np.float32) im_padded = np.ones(
(self.input_size, self.input_size, 3), dtype=np.float32)
im_padded[:, :, 2] = 0.485 * 255 im_padded[:, :, 2] = 0.485 * 255
im_padded[:, :, 1] = 0.456 * 255 im_padded[:, :, 1] = 0.456 * 255
im_padded[:, :, 0] = 0.406 * 255 im_padded[:, :, 0] = 0.406 * 255
@ -661,24 +741,29 @@ class SASTProcessTrain(object):
sw = int(np.random.rand() * del_w) sw = int(np.random.rand() * del_w)
# Padding # Padding
im_padded[sh: sh + new_h, sw: sw + new_w, :] = im.copy() im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
text_polys[:, :, 0] += sw text_polys[:, :, 0] += sw
text_polys[:, :, 1] += sh text_polys[:, :, 1] += sh
score_map, border_map, training_mask = self.generate_tcl_label((self.input_size, self.input_size), score_map, border_map, training_mask = self.generate_tcl_label(
text_polys, text_tags, 0.25) (self.input_size, self.input_size), text_polys, text_tags, 0.25)
# SAST head # SAST head
tvo_map, tco_map = self.generate_tvo_and_tco((self.input_size, self.input_size), text_polys, text_tags, tcl_ratio=0.3, ds_ratio=0.25) tvo_map, tco_map = self.generate_tvo_and_tco(
(self.input_size, self.input_size),
text_polys,
text_tags,
tcl_ratio=0.3,
ds_ratio=0.25)
# print("test--------tvo_map shape:", tvo_map.shape) # print("test--------tvo_map shape:", tvo_map.shape)
im_padded[:, :, 2] -= 0.485 * 255 im_padded[:, :, 2] -= 0.485 * 255
im_padded[:, :, 1] -= 0.456 * 255 im_padded[:, :, 1] -= 0.456 * 255
im_padded[:, :, 0] -= 0.406 * 255 im_padded[:, :, 0] -= 0.406 * 255
im_padded[:, :, 2] /= (255.0 * 0.229) im_padded[:, :, 2] /= (255.0 * 0.229)
im_padded[:, :, 1] /= (255.0 * 0.224) im_padded[:, :, 1] /= (255.0 * 0.224)
im_padded[:, :, 0] /= (255.0 * 0.225) im_padded[:, :, 0] /= (255.0 * 0.225)
im_padded = im_padded.transpose((2, 0, 1)) im_padded = im_padded.transpose((2, 0, 1))
data['image'] = im_padded[::-1, :, :] data['image'] = im_padded[::-1, :, :]
data['score_map'] = score_map[np.newaxis, :, :] data['score_map'] = score_map[np.newaxis, :, :]
@ -686,4 +771,4 @@ class SASTProcessTrain(object):
data['training_mask'] = training_mask[np.newaxis, :, :] data['training_mask'] = training_mask[np.newaxis, :, :]
data['tvo_map'] = tvo_map.transpose((2, 0, 1)) data['tvo_map'] = tvo_map.transpose((2, 0, 1))
data['tco_map'] = tco_map.transpose((2, 0, 1)) data['tco_map'] = tco_map.transpose((2, 0, 1))
return data return data