PaddleOCR/ppocr/utils/e2e_utils/visual.py

344 lines
11 KiB
Python

import os
import numpy as np
import cv2
import time
def visualize_e2e_result(im_fn,
                         poly_list,
                         seq_strs,
                         src_im,
                         gt_dir=None,
                         valid_set='partvgg',
                         result_path='./out'):
    """
    Draw GT polygons, detected polygons and recognized strings on a copy of
    ``src_im`` and save it as ``<result_path>/<prefix>_detection.jpg``.

    Args:
        im_fn: image file name/path; its basename without the extension
            selects ``<gt_dir>/<prefix>.txt`` and names the output file.
        poly_list: detected polygons. Each is drawn with cv2.polylines and
            its ``[0, 0]`` / ``[0, 1]`` elements are used as the text origin,
            so presumably an (N, 2) point array -- TODO confirm with callers.
        seq_strs: recognized strings, paired with ``poly_list`` via ``zip``.
        src_im: image to draw on; it is copied, not modified.
        gt_dir: directory of GT ``.txt`` files. ``None`` falls back to the
            path that used to be hard-coded here.
        valid_set: GT layout: ``'partvgg'`` (exactly 4 points per line) or
            ``'totaltext'`` (any even number of coordinates per line).
        result_path: output directory; created if missing.

    Raises:
        ValueError: if ``valid_set`` is not one of the supported layouts.
    """
    if gt_dir is None:
        gt_dir = "/Users/hongyongjie/Downloads/part_vgg_synth/train"
    if valid_set not in ('partvgg', 'totaltext'):
        raise ValueError("unsupported valid_set: {}".format(valid_set))

    # splitext keeps extension-less names intact; rfind('.') returned -1 for
    # them and chopped off the last character.
    im_prefix = os.path.splitext(os.path.basename(im_fn))[0]
    vis_det_img = src_im.copy()

    # GT file: one polygon per line, comma-separated coordinates before the
    # first tab. The context manager closes the file (it used to leak).
    text_path = os.path.join(gt_dir, im_prefix + '.txt')
    with open(text_path, 'r') as fid:
        lines = [line.strip() for line in fid]
    for line in lines:
        if not line:  # blank lines carry no polygon (float('') would raise)
            continue
        coords = list(map(float, line.split('\t')[0].split(',')))
        if valid_set == 'partvgg':
            gt_poly = np.array(coords).reshape(1, 4, 2)
        else:
            # '//' because reshape needs an int; '/' yields a float on Py3.
            gt_poly = np.array(coords).reshape(1, len(coords) // 2, 2)
        cv2.polylines(
            vis_det_img,
            gt_poly.astype(np.int32),
            isClosed=True,
            color=(255, 0, 0),
            thickness=2)

    # Detections, with the recognized string anchored at the first point.
    for detected_poly, recognized_str in zip(poly_list, seq_strs):
        cv2.polylines(
            vis_det_img,
            np.array(detected_poly[np.newaxis, ...]).astype(np.int32),
            isClosed=True,
            color=(0, 0, 255),
            thickness=2)
        cv2.putText(
            vis_det_img,
            recognized_str,
            org=(int(detected_poly[0, 0]), int(detected_poly[0, 1])),
            fontFace=cv2.FONT_HERSHEY_COMPLEX,
            fontScale=0.7,
            color=(0, 255, 0),
            thickness=1)

    if not os.path.exists(result_path):
        os.makedirs(result_path)
    cv2.imwrite("{}/{}_detection.jpg".format(result_path, im_prefix),
                vis_det_img)
def visualization_output(src_image,
                         f_tcl,
                         f_chars,
                         output_dir,
                         image_prefix=None):
    """
    Save debug overlays of the TCL map and the character map.

    Writes ``<prefix>_0_tcl.jpg`` and ``<prefix>_1_chars.jpg`` into
    ``output_dir``. In both files, channel 1 of the de-normalized image is
    replaced by the map (resized to the image size) times 255.

    Args:
        src_image: normalized 3-channel image in CHW layout. The mean/std are
            reshaped to (3, 1, 1) and the array is transposed to HWC. The
            caller's array is NOT modified.
        f_tcl: TCL map, resized with cv2.resize to the image's (W, H);
            presumably scores in [0, 1] -- TODO confirm.
        f_chars: character scores with the class axis at axis 2 (argmax is
            taken over axis=2). Class index 95 is treated as background --
            presumably the "no character" class; verify against the dict.
        output_dir: directory to write into; created if missing.
        image_prefix: file-name prefix; defaults to the current time in ms.
    """
    im_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1))
    im_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1))
    # Undo normalization out-of-place: in-place ``*=`` / ``+=`` would mutate
    # the caller's array (and fail outright for integer dtypes).
    src_image = src_image * im_std + im_mean
    # CHW -> HWC, reverse the channel order, scale to [0, 255].
    src_image = src_image.transpose([1, 2, 0])[:, :, ::-1] * 255
    H, W, _ = src_image.shape

    file_prefix = image_prefix if image_prefix is not None else str(
        int(time.time() * 1000))
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # TCL overlay.
    tcl_file_name = os.path.join(output_dir, file_prefix + '_0_tcl.jpg')
    vis_tcl_img = src_image.copy()
    vis_tcl_img[:, :, 1] = cv2.resize(f_tcl, dsize=(W, H)) * 255
    cv2.imwrite(tcl_file_name, vis_tcl_img)

    # Char overlay: a 0/1 mask, 1.0 wherever the arg-max class is not the
    # background index 95.
    char_file_name = os.path.join(output_dir, file_prefix + '_1_chars.jpg')
    vis_char_img = src_image.copy()
    char_idx = np.argmax(f_chars, axis=2)
    char_mask = (char_idx != 95).astype('float32')[:, :, np.newaxis]
    vis_char_img[:, :, 1] = cv2.resize(char_mask, dsize=(W, H)) * 255
    cv2.imwrite(char_file_name, vis_char_img)
def visualize_point_result(im_fn, point_list, point_pair_list, src_im, gt_dir,
                           result_path):
    """
    Draw GT polygons plus predicted points and point pairs on a copy of
    ``src_im``; save as ``<result_path>/<prefix>_border_points.jpg``.

    Args:
        im_fn: image file name/path; its basename without the extension
            selects ``<gt_dir>/<prefix>.txt`` and names the output file.
        point_list: predicted points, each drawn as a circle.
        point_pair_list: pairs ``(a, b)`` zipped with ``point_list``; a line
            is drawn from ``a`` to ``b`` and both ends are circled.
        src_im: image to draw on; it is copied, not modified.
        gt_dir: directory of GT ``.txt`` files (comma-separated coordinates
            before the first tab, one polygon per line).
        result_path: output directory; created if missing.
    """
    im_prefix = os.path.splitext(os.path.basename(im_fn))[0]
    vis_det_img = src_im.copy()

    # Draw GT polygons. The context manager closes the file (it used to leak).
    text_path = os.path.join(gt_dir, im_prefix + '.txt')
    with open(text_path, 'r') as fid:
        lines = [line.strip() for line in fid]
    for line in lines:
        if not line:  # blank lines carry no polygon
            continue
        coords = list(map(float, line.split('\t')[0].split(',')))
        # '//' because reshape needs an int; '/' yields a float on Py3.
        gt_poly = np.array(coords).reshape(1, len(coords) // 2, 2)
        cv2.polylines(
            vis_det_img,
            gt_poly.astype(np.int32),
            isClosed=True,
            color=(255, 255, 255),
            thickness=1)

    # cv2 drawing calls require integer coordinates.
    for point, point_pair in zip(point_list, point_pair_list):
        center = tuple(int(v) for v in point)
        start = tuple(int(v) for v in point_pair[0])
        end = tuple(int(v) for v in point_pair[1])
        cv2.line(vis_det_img, start, end, (0, 255, 255), thickness=1)
        cv2.circle(vis_det_img, center, 2, (0, 0, 255))
        cv2.circle(vis_det_img, start, 2, (255, 0, 0))
        cv2.circle(vis_det_img, end, 2, (0, 255, 0))

    if not os.path.exists(result_path):
        os.makedirs(result_path)
    cv2.imwrite("{}/{}_border_points.jpg".format(result_path, im_prefix),
                vis_det_img)
def resize_image(im, max_side_len=512, max_stride=128):
    """
    Scale ``im`` so its longer side is ``max_side_len``, then round both
    sides up to a multiple of ``max_stride`` as required by the network.

    :param im: the input image, unpacked as (h, w, channels)
    :param max_side_len: target length of the longer side before rounding.
        If it is not a multiple of ``max_stride``, rounding up can exceed it
        by up to ``max_stride - 1`` pixels.
    :param max_stride: both output sides are multiples of this (default 128)
    :return: the resized image and the resize ratios (ratio_h, ratio_w)
    """
    h, w, _ = im.shape
    # The longer side decides the scale.
    ratio = float(max_side_len) / max(h, w)
    # max(..., 1): a very thin image must not collapse to a 0-pixel side,
    # which cv2.resize rejects.
    resize_h = max(int(h * ratio), 1)
    resize_w = max(int(w * ratio), 1)
    resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
    resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
    im = cv2.resize(im, (int(resize_w), int(resize_h)))
    ratio_h = resize_h / float(h)
    ratio_w = resize_w / float(w)
    return im, (ratio_h, ratio_w)
def resize_image_min(im, max_side_len=512, max_stride=128):
    """
    Scale ``im`` so its SHORTER side is ``max_side_len``, then round both
    sides up to a multiple of ``max_stride``.

    Despite its name, ``max_side_len`` is the target of the shorter side
    here, so the longer side ends up at least that long. Prints a marker line
    to stdout on every call.

    :param im: the input image, unpacked as (h, w, channels)
    :param max_side_len: target length of the shorter side before rounding
    :param max_stride: both output sides are multiples of this (default 128)
    :return: the resized image and the resize ratios (ratio_h, ratio_w)
    """
    print('--> Using resize_image_min')
    h, w, _ = im.shape
    # The shorter side decides the scale.
    ratio = float(max_side_len) / min(h, w)
    resize_h = int(h * ratio)
    resize_w = int(w * ratio)
    resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
    resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
    im = cv2.resize(im, (int(resize_w), int(resize_h)))
    ratio_h = resize_h / float(h)
    ratio_w = resize_w / float(w)
    return im, (ratio_h, ratio_w)
def resize_image_for_totaltext(im, max_side_len=512, max_stride=128):
    """
    Upscale ``im`` by 1.25, capped so the height does not exceed
    ``max_side_len``, then round both sides up to a multiple of
    ``max_stride``.

    Only the height is capped; the width is scaled by the same ratio and is
    not bounded by ``max_side_len``.

    :param im: the input image, unpacked as (h, w, channels)
    :param max_side_len: upper bound for the scaled height (before rounding)
    :param max_stride: both output sides are multiples of this (default 128)
    :return: the resized image and the resize ratios (ratio_h, ratio_w)
    """
    h, w, _ = im.shape
    ratio = 1.25
    if h * ratio > max_side_len:
        ratio = float(max_side_len) / h
    # max(..., 1): a very narrow image must not collapse to a 0-pixel side,
    # which cv2.resize rejects.
    resize_h = max(int(h * ratio), 1)
    resize_w = max(int(w * ratio), 1)
    resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
    resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
    im = cv2.resize(im, (int(resize_w), int(resize_h)))
    ratio_h = resize_h / float(h)
    ratio_w = resize_w / float(w)
    return im, (ratio_h, ratio_w)
def point_pair2poly(point_pair_list):
    """
    Turn a sequence of vertical point pairs into one polygon outline.

    The first point of every pair is taken in order, followed by the second
    point of every pair in reverse order. The result walks along one side and
    returns along the other (intended to be clockwise, given suitably ordered
    pairs).

    Returns:
        (poly, pair_info): ``poly`` is the stacked points reshaped to (-1, 2);
        ``pair_info`` is (max, min, mean) of the distances within each pair.
        An empty input raises from ``max()`` on the empty distance array.
    """
    distances = np.array(
        [np.linalg.norm(pair[0] - pair[1]) for pair in point_pair_list])
    pair_info = (distances.max(), distances.min(), distances.mean())

    forward_side = [pair[0] for pair in point_pair_list]
    return_side = [pair[1] for pair in reversed(point_pair_list)]
    outline = np.array(forward_side + return_side)
    return outline.reshape(-1, 2), pair_info
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
    """
    Cut (or, with ratios outside [0, 1], extend) a quad along its width.

    Edge quad[0] -> quad[1] and edge quad[3] -> quad[2] are each sampled at
    ``begin_width_ratio`` and ``end_width_ratio`` (0 = the edge's first point,
    1 = its second; the ratios are cast to float32).

    Returns:
        array of 4 points:
        [edge01 @ begin, edge01 @ end, edge32 @ end, edge32 @ begin].
    """
    ratios = np.array(
        [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
    edge_01 = quad[0] + (quad[1] - quad[0]) * ratios
    edge_32 = quad[3] + (quad[2] - quad[3]) * ratios
    first_begin, first_end = edge_01
    second_begin, second_end = edge_32
    return np.array([first_begin, first_end, second_end, second_begin])
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
    """
    Push both short ends of a polygon outward along its width.

    The "head" quad is [poly[0], poly[1], poly[-2], poly[-1]]. The "tail"
    quad is the four points around index n // 2 (n = poly.shape[0]). This
    presumably assumes the layout produced by point_pair2poly (one side
    forward, the other side backward) -- confirm with callers. Each end is
    extrapolated with shrink_quad_along_width, moving by roughly
    ``shrink_ratio_of_width`` times the length of that quad's edge (0, 3).

    Note: ``poly`` is modified in place (indices 0, -1, n//2 - 1 and n//2 are
    overwritten) and the same array is returned.
    """
    half = poly.shape[0] // 2
    eps = 1e-6  # guards the division when the width edge is degenerate

    # Head end: a negative begin ratio extends edges backwards.
    head = np.array(
        [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
    head_side = np.linalg.norm(head[0] - head[3])
    head_step = np.linalg.norm(head[0] - head[1]) + eps
    head_ratio = -shrink_ratio_of_width * head_side / head_step
    new_head = shrink_quad_along_width(head, head_ratio, 1.0)

    # Tail end: an end ratio above 1 extends edges forwards.
    tail = np.array(
        [poly[half - 2], poly[half - 1], poly[half], poly[half + 1]],
        dtype=np.float32)
    tail_side = np.linalg.norm(tail[0] - tail[3])
    tail_step = np.linalg.norm(tail[0] - tail[1]) + eps
    tail_ratio = 1.0 + shrink_ratio_of_width * tail_side / tail_step
    new_tail = shrink_quad_along_width(tail, 0.0, tail_ratio)

    poly[0] = new_head[0]
    poly[-1] = new_head[-1]
    poly[half - 1] = new_tail[1]
    poly[half] = new_tail[2]
    return poly
def norm2(x, axis=None):
    """
    Euclidean (L2) norm of ``x``.

    Args:
        x: array-like supporting ``** 2``.
        axis: ``None`` for the norm over all elements; otherwise the axis
            (or axes) to reduce over, as accepted by ``np.sum``.

    Returns:
        ``sqrt(sum(x ** 2))`` over the requested axis.
    """
    # Pass axis straight through: a truthiness test (``if axis:``) would
    # silently treat axis=0 like None and return the global norm.
    return np.sqrt(np.sum(x**2, axis=axis))
def cos(p1, p2):
    """
    Cosine of the angle between ``p1`` and ``p2``.

    Computed as the element-wise dot product divided by the product of the
    two L2 norms (via ``norm2``). Undefined for zero-length inputs, because
    that divides by zero.
    """
    dot_product = (p1 * p2).sum()
    magnitude = norm2(p1) * norm2(p2)
    return dot_product / magnitude
def generate_direction_info(image_fn,
                            H,
                            W,
                            ratio_h,
                            ratio_w,
                            max_length=640,
                            out_scale=4,
                            gt_dir=None):
    """
    Rasterize a per-instance direction map from the GT quads of one image.

    For each GT quad whose text is not ``'###'``, the following vector is
    computed and filled into the quad's area with cv2.fillPoly: from the
    midpoint of points (0, 3) to the midpoint of points (1, 2), divided by
    the number of characters in the transcription.

    Args:
        image_fn: image path; its basename without the extension selects
            ``<gt_dir>/<prefix>.txt``. Each line is expected to be
            ``x1,y1,...,x4,y4<TAB>text`` (exactly 4 points, one tab).
        H, W: input height/width; the map is (H // out_scale, W // out_scale).
        ratio_h, ratio_w: scales applied to GT y / x coordinates (presumably
            the ratios returned by the resize_image* helpers -- confirm).
        max_length: unused; kept for backward compatibility.
        out_scale: downsampling factor from input pixels to map cells.
        gt_dir: GT directory; ``None`` falls back to a hard-coded path.

    Returns:
        float64 array of shape (2, H // out_scale, W // out_scale). Channel 0
        holds the x component and channel 1 the y component of the
        per-character vector. Pixels outside any quad are 0.
    """
    im_prefix = os.path.splitext(os.path.basename(image_fn))[0]
    instance_direction_map = np.zeros(
        shape=[H // out_scale, W // out_scale, 3])
    if gt_dir is None:
        gt_dir = '/home/vis/huangzuming/data/SYNTH_DATA/part_vgg_synth_icdar/processed/val/poly'
    # Column 0 of the GT points is x (scaled by ratio_w), column 1 is y.
    xy_scale = np.array([ratio_w, ratio_h]).reshape(1, 2)

    text_path = os.path.join(gt_dir, im_prefix + '.txt')
    with open(text_path, 'r') as fid:  # closed on exit (it used to leak)
        lines = [line.strip() for line in fid]
    for line in lines:
        if not line:  # blank lines carry no instance
            continue
        coords, txt = line.split('\t')
        if txt == '###':  # don't-care region
            continue
        coords = list(map(float, coords.strip().split(',')))
        poly = np.array(coords).reshape(4, 2) * xy_scale / out_scale
        mid_idx = poly.shape[0] // 2
        direct_vector = (
            (poly[mid_idx] + poly[mid_idx - 1]) - (poly[0] + poly[-1])) / 2.0
        # Per-character step along the quad (txt is non-empty here).
        direct_vector /= len(txt)
        # Third channel is a constant placeholder and is dropped on return.
        avg_char_distance = 1.0
        direct_label = (direct_vector[0], direct_vector[1], avg_char_distance)
        cv2.fillPoly(instance_direction_map,
                     poly.round().astype(np.int32)[np.newaxis, :, :],
                     direct_label)
    instance_direction_map = instance_direction_map.transpose([2, 0, 1])
    return instance_direction_map[:2, ...]