Merge pull request #1041 from LDOUBLEV/fixocr

add code comment
This commit is contained in:
MissPenguin 2020-10-28 15:58:27 +08:00 committed by GitHub
commit 251555f74e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 40 additions and 1 deletion

View File

@ -28,14 +28,22 @@ from .make_border_map import MakeBorderMap
class DBProcessTrain(object): class DBProcessTrain(object):
""" """
DB pre-process for Train mode The pre-process of DB for train mode
""" """
def __init__(self, params): def __init__(self, params):
"""
:param params: dict of params
"""
self.img_set_dir = params['img_set_dir'] self.img_set_dir = params['img_set_dir']
self.image_shape = params['image_shape'] self.image_shape = params['image_shape']
def order_points_clockwise(self, pts): def order_points_clockwise(self, pts):
"""
Sort the points in the box clockwise
:param pts: points with shape [4, 2]
:return: sorted points
"""
rect = np.zeros((4, 2), dtype="float32") rect = np.zeros((4, 2), dtype="float32")
s = pts.sum(axis=1) s = pts.sum(axis=1)
rect[0] = pts[np.argmin(s)] rect[0] = pts[np.argmin(s)]
@ -46,6 +54,12 @@ class DBProcessTrain(object):
return rect return rect
def make_data_dict(self, imgvalue, entry): def make_data_dict(self, imgvalue, entry):
"""
create input dict
:param imgvalue: input image
:param entry: dict of annotations information
:return: created dict of input data information
"""
boxes = [] boxes = []
texts = [] texts = []
ignores = [] ignores = []
@ -71,6 +85,11 @@ class DBProcessTrain(object):
return data return data
def NormalizeImage(self, data): def NormalizeImage(self, data):
"""
Normalize input image
:param data: input dict
:return: new dict with normalized image
"""
im = data['image'] im = data['image']
img_mean = [0.485, 0.456, 0.406] img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225] img_std = [0.229, 0.224, 0.225]
@ -84,6 +103,11 @@ class DBProcessTrain(object):
return data return data
def FilterKeys(self, data): def FilterKeys(self, data):
"""
Filter keys
:param data: dict
:return: the filtered dict
"""
filter_keys = ['polys', 'texts', 'ignore_tags', 'shape'] filter_keys = ['polys', 'texts', 'ignore_tags', 'shape']
for key in filter_keys: for key in filter_keys:
if key in data: if key in data:
@ -91,6 +115,11 @@ class DBProcessTrain(object):
return data return data
def convert_label_infor(self, label_infor): def convert_label_infor(self, label_infor):
"""
parse the annotations using json.loads
:param label_infor: bytes of one tab-separated label line
:return: (image, decoded annotations)
"""
label_infor = label_infor.decode() label_infor = label_infor.decode()
label_infor = label_infor.encode('utf-8').decode('utf-8-sig') label_infor = label_infor.encode('utf-8').decode('utf-8-sig')
substr = label_infor.strip("\n").split("\t") substr = label_infor.strip("\n").split("\t")
@ -184,6 +213,11 @@ class DBProcessTest(object):
return im, (ratio_h, ratio_w) return im, (ratio_h, ratio_w)
def resize_image_type1(self, im): def resize_image_type1(self, im):
"""
resize image to the size given by self.image_shape
:param im: input image
:return: resized image and resize ratios (ratio_h, ratio_w)
"""
resize_h, resize_w = self.image_shape resize_h, resize_w = self.image_shape
ori_h, ori_w = im.shape[:2] # (h, w, c) ori_h, ori_w = im.shape[:2] # (h, w, c)
im = cv2.resize(im, (int(resize_w), int(resize_h))) im = cv2.resize(im, (int(resize_w), int(resize_h)))
@ -192,6 +226,11 @@ class DBProcessTest(object):
return im, (ratio_h, ratio_w) return im, (ratio_h, ratio_w)
def normalize(self, im): def normalize(self, im):
"""
Normalize image
:param im: input image
:return: Normalized image
"""
img_mean = [0.485, 0.456, 0.406] img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225] img_std = [0.229, 0.224, 0.225]
im = im.astype(np.float32, copy=False) im = im.astype(np.float32, copy=False)