commit
251555f74e
|
@ -28,14 +28,22 @@ from .make_border_map import MakeBorderMap
|
|||
|
||||
class DBProcessTrain(object):
|
||||
"""
|
||||
DB pre-process for Train mode
|
||||
The pre-process of DB for train mode
|
||||
"""
|
||||
|
||||
def __init__(self, params):
|
||||
"""
|
||||
:param params: dict of params
|
||||
"""
|
||||
self.img_set_dir = params['img_set_dir']
|
||||
self.image_shape = params['image_shape']
|
||||
|
||||
def order_points_clockwise(self, pts):
|
||||
"""
|
||||
Sort the points in the box clockwise
|
||||
:param pts: points with shape [4, 2]
|
||||
:return: sorted points
|
||||
"""
|
||||
rect = np.zeros((4, 2), dtype="float32")
|
||||
s = pts.sum(axis=1)
|
||||
rect[0] = pts[np.argmin(s)]
|
||||
|
@ -46,6 +54,12 @@ class DBProcessTrain(object):
|
|||
return rect
|
||||
|
||||
def make_data_dict(self, imgvalue, entry):
|
||||
"""
|
||||
create input dict
|
||||
:param imgvalue: input image
|
||||
:param entry: dict of annotations information
|
||||
:return: created dict of input data information
|
||||
"""
|
||||
boxes = []
|
||||
texts = []
|
||||
ignores = []
|
||||
|
@ -71,6 +85,11 @@ class DBProcessTrain(object):
|
|||
return data
|
||||
|
||||
def NormalizeImage(self, data):
|
||||
"""
|
||||
Normalize input image
|
||||
:param data: input dict
|
||||
:return: new dict with normalized image
|
||||
"""
|
||||
im = data['image']
|
||||
img_mean = [0.485, 0.456, 0.406]
|
||||
img_std = [0.229, 0.224, 0.225]
|
||||
|
@ -84,6 +103,11 @@ class DBProcessTrain(object):
|
|||
return data
|
||||
|
||||
def FilterKeys(self, data):
|
||||
"""
|
||||
Filter keys
|
||||
:param data: dict
|
||||
:return:
|
||||
"""
|
||||
filter_keys = ['polys', 'texts', 'ignore_tags', 'shape']
|
||||
for key in filter_keys:
|
||||
if key in data:
|
||||
|
@ -91,6 +115,11 @@ class DBProcessTrain(object):
|
|||
return data
|
||||
|
||||
def convert_label_infor(self, label_infor):
|
||||
"""
|
||||
encode annotations using json.loads
|
||||
:param label_infor: string
|
||||
:return: (image, encoded annotations)
|
||||
"""
|
||||
label_infor = label_infor.decode()
|
||||
label_infor = label_infor.encode('utf-8').decode('utf-8-sig')
|
||||
substr = label_infor.strip("\n").split("\t")
|
||||
|
@ -184,6 +213,11 @@ class DBProcessTest(object):
|
|||
return im, (ratio_h, ratio_w)
|
||||
|
||||
def resize_image_type1(self, im):
|
||||
"""
|
||||
resize image to a size self.image_shape
|
||||
:param im: input image
|
||||
:return: normalized image and resize ratio
|
||||
"""
|
||||
resize_h, resize_w = self.image_shape
|
||||
ori_h, ori_w = im.shape[:2] # (h, w, c)
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
|
@ -192,6 +226,11 @@ class DBProcessTest(object):
|
|||
return im, (ratio_h, ratio_w)
|
||||
|
||||
def normalize(self, im):
|
||||
"""
|
||||
Normalize image
|
||||
:param im: input image
|
||||
:return: Normalized image
|
||||
"""
|
||||
img_mean = [0.485, 0.456, 0.406]
|
||||
img_std = [0.229, 0.224, 0.225]
|
||||
im = im.astype(np.float32, copy=False)
|
||||
|
|
Loading…
Reference in New Issue