diff --git a/ppocr/data/rec/img_tools.py b/ppocr/data/rec/img_tools.py index 8b497e6b..4d51f7ee 100755 --- a/ppocr/data/rec/img_tools.py +++ b/ppocr/data/rec/img_tools.py @@ -19,6 +19,8 @@ import random from ppocr.utils.utility import initial_logger logger = initial_logger() +from .text_image_aug.augment import tia_distort, tia_stretch, tia_perspective + def get_bounding_box_rect(pos): left = min(pos[0]) @@ -196,6 +198,9 @@ class Config: self.h = h self.perspective = True + self.stretch = True + self.distort = True + self.crop = True self.affine = False self.reverse = True @@ -299,41 +304,40 @@ def warp(img, ang): config.make(w, h, ang) new_img = img + prob = 0.4 + + if config.distort: + img_height, img_width = img.shape[0:2] + if random.random() <= prob and img_height >= 20 and img_width >= 20: + new_img = tia_distort(new_img, random.randint(3, 6)) + + if config.stretch: + img_height, img_width = img.shape[0:2] + if random.random() <= prob and img_height >= 20 and img_width >= 20: + new_img = tia_stretch(new_img, random.randint(3, 6)) + if config.perspective: - tp = random.randint(1, 100) - if tp >= 50: - warpR, (r1, c1), ratio, dst = get_warpR(config) - new_w = int(np.max(dst[:, 0])) - int(np.min(dst[:, 0])) - new_img = cv2.warpPerspective( - new_img, - warpR, (int(new_w * ratio), h), - borderMode=config.borderMode) + if random.random() <= prob: + new_img = tia_perspective(new_img) + if config.crop: img_height, img_width = img.shape[0:2] - tp = random.randint(1, 100) - if tp >= 50 and img_height >= 20 and img_width >= 20: + if random.random() <= prob and img_height >= 20 and img_width >= 20: new_img = get_crop(new_img) - if config.affine: - warpT = get_warpAffine(config) - new_img = cv2.warpAffine( - new_img, warpT, (w, h), borderMode=config.borderMode) + if config.blur: - tp = random.randint(1, 100) - if tp >= 50: + if random.random() <= prob: new_img = blur(new_img) if config.color: - tp = random.randint(1, 100) - if tp >= 50: + if random.random() <= prob: new_img = cvtColor(new_img) if config.jitter: new_img = jitter(new_img) if config.noise: - tp = random.randint(1, 100) - if tp >= 50: + if random.random() <= prob: new_img = add_gasuss_noise(new_img) if config.reverse: - tp = random.randint(1, 100) - if tp >= 50: + if random.random() <= prob: new_img = 255 - new_img return new_img @@ -360,7 +364,7 @@ def process_image(img, text = char_ops.encode(label) if len(text) == 0 or len(text) > max_text_length: logger.info( - "Warning in ppocr/data/rec/img_tools.py: Wrong data type." + "Warning in ppocr/data/rec/img_tools.py:line362: Wrong data type." "Excepted string with length between 1 and {}, but " "got '{}'. Label is '{}'".format(max_text_length, len(text), label)) @@ -382,6 +386,7 @@ def process_image(img, % loss_type return (norm_img) + def resize_norm_img_srn(img, image_shape): imgC, imgH, imgW = image_shape @@ -408,30 +413,39 @@ def resize_norm_img_srn(img, image_shape): return np.reshape(img_black, (c, row, col)).astype(np.float32) -def srn_other_inputs(image_shape, - num_heads, - max_text_length, - char_num): + +def srn_other_inputs(image_shape, num_heads, max_text_length, char_num): imgC, imgH, imgW = image_shape feature_dim = int((imgH / 8) * (imgW / 8)) - encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype('int64') - gsrm_word_pos = np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype('int64') + encoder_word_pos = np.array(range(0, feature_dim)).reshape( + (feature_dim, 1)).astype('int64') + gsrm_word_pos = np.array(range(0, max_text_length)).reshape( + (max_text_length, 1)).astype('int64') - lbl_weight = np.array([int(char_num-1)] * max_text_length).reshape((-1,1)).astype('int64') + lbl_weight = np.array([int(char_num - 1)] * max_text_length).reshape( + (-1, 1)).astype('int64') - gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) - gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape([-1, 1, max_text_length, max_text_length]) - gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, [1, num_heads, 1, 1]) * [-1e9] + gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) + gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape( + [-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, + [1, num_heads, 1, 1]) * [-1e9] - gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape([-1, 1, max_text_length, max_text_length]) - gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, [1, num_heads, 1, 1]) * [-1e9] + gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape( + [-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, + [1, num_heads, 1, 1]) * [-1e9] encoder_word_pos = encoder_word_pos[np.newaxis, :] gsrm_word_pos = gsrm_word_pos[np.newaxis, :] - return [lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] + return [ + lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2 + ] + def process_image_srn(img, image_shape, @@ -453,14 +467,16 @@ def process_image_srn(img, return None else: if loss_type == "srn": - text_padded = [int(char_num-1)] * max_text_length + text_padded = [int(char_num - 1)] * max_text_length for i in range(len(text)): text_padded[i] = text[i] lbl_weight[i] = [1.0] text_padded = np.array(text_padded) text = text_padded.reshape(-1, 1) - return (norm_img, text,encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2,lbl_weight) + return (norm_img, text, encoder_word_pos, gsrm_word_pos, + gsrm_slf_attn_bias1, gsrm_slf_attn_bias2, lbl_weight) else: assert False, "Unsupport loss_type %s in process_image"\ % loss_type - return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2) + return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2) diff --git a/ppocr/data/rec/text_image_aug/augment.py b/ppocr/data/rec/text_image_aug/augment.py new file mode 100644 index 00000000..10ffbbd1 --- /dev/null +++ b/ppocr/data/rec/text_image_aug/augment.py @@ -0,0 +1,107 @@ +# -*- coding:utf-8 -*- +# Author: RubanSeven +# Reference: https://github.com/RubanSeven/Text-Image-Augmentation-python + +# import cv2 +import numpy as np +from .warp_mls import WarpMLS + + +def tia_distort(src, segment=4): + img_h, img_w = src.shape[:2] + + cut = img_w // segment + thresh = cut // 3 + + src_pts = list() + dst_pts = list() + + src_pts.append([0, 0]) + src_pts.append([img_w, 0]) + src_pts.append([img_w, img_h]) + src_pts.append([0, img_h]) + + dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)]) + dst_pts.append( + [img_w - np.random.randint(thresh), np.random.randint(thresh)]) + dst_pts.append( + [img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)]) + dst_pts.append( + [np.random.randint(thresh), img_h - np.random.randint(thresh)]) + + half_thresh = thresh * 0.5 + + for cut_idx in np.arange(1, segment, 1): + src_pts.append([cut * cut_idx, 0]) + src_pts.append([cut * cut_idx, img_h]) + dst_pts.append([ + cut * cut_idx + np.random.randint(thresh) - half_thresh, + np.random.randint(thresh) - half_thresh + ]) + dst_pts.append([ + cut * cut_idx + np.random.randint(thresh) - half_thresh, + img_h + np.random.randint(thresh) - half_thresh + ]) + + trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h) + dst = trans.generate() + + return dst + + +def tia_stretch(src, segment=4): + img_h, img_w = src.shape[:2] + + cut = img_w // segment + thresh = cut * 4 // 5 + + src_pts = list() + dst_pts = list() + + src_pts.append([0, 0]) + src_pts.append([img_w, 0]) + src_pts.append([img_w, img_h]) + src_pts.append([0, img_h]) + + dst_pts.append([0, 0]) + dst_pts.append([img_w, 0]) + dst_pts.append([img_w, img_h]) + dst_pts.append([0, img_h]) + + half_thresh = thresh * 0.5 + + for cut_idx in np.arange(1, segment, 1): + move = np.random.randint(thresh) - half_thresh + src_pts.append([cut * cut_idx, 0]) + src_pts.append([cut * cut_idx, img_h]) + dst_pts.append([cut * cut_idx + move, 0]) + dst_pts.append([cut * cut_idx + move, img_h]) + + trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h) + dst = trans.generate() + + return dst + + +def tia_perspective(src): + img_h, img_w = src.shape[:2] + + thresh = img_h // 2 + + src_pts = list() + dst_pts = list() + + src_pts.append([0, 0]) + src_pts.append([img_w, 0]) + src_pts.append([img_w, img_h]) + src_pts.append([0, img_h]) + + dst_pts.append([0, np.random.randint(thresh)]) + dst_pts.append([img_w, np.random.randint(thresh)]) + dst_pts.append([img_w, img_h - np.random.randint(thresh)]) + dst_pts.append([0, img_h - np.random.randint(thresh)]) + + trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h) + dst = trans.generate() + + return dst diff --git a/ppocr/data/rec/text_image_aug/warp_mls.py b/ppocr/data/rec/text_image_aug/warp_mls.py new file mode 100644 index 00000000..8994fac0 --- /dev/null +++ b/ppocr/data/rec/text_image_aug/warp_mls.py @@ -0,0 +1,154 @@ +# -*- coding:utf-8 -*- +# Author: RubanSeven +# Reference: https://github.com/RubanSeven/Text-Image-Augmentation-python +import math +import numpy as np + + +class WarpMLS: + def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.): + self.src = src + self.src_pts = src_pts + self.dst_pts = dst_pts + self.pt_count = len(self.dst_pts) + self.dst_w = dst_w + self.dst_h = dst_h + self.trans_ratio = trans_ratio + self.grid_size = 100 + self.rdx = np.zeros((self.dst_h, self.dst_w)) + self.rdy = np.zeros((self.dst_h, self.dst_w)) + + @staticmethod + def __bilinear_interp(x, y, v11, v12, v21, v22): + return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 * + (1 - y) + v22 * y) * x + + def generate(self): + self.calc_delta() + return self.gen_img() + + def calc_delta(self): + w = np.zeros(self.pt_count, dtype=np.float32) + + if self.pt_count < 2: + return + + i = 0 + while 1: + if self.dst_w <= i < self.dst_w + self.grid_size - 1: + i = self.dst_w - 1 + elif i >= self.dst_w: + break + + j = 0 + while 1: + if self.dst_h <= j < self.dst_h + self.grid_size - 1: + j = self.dst_h - 1 + elif j >= self.dst_h: + break + + sw = 0 + swp = np.zeros(2, dtype=np.float32) + swq = np.zeros(2, dtype=np.float32) + new_pt = np.zeros(2, dtype=np.float32) + cur_pt = np.array([i, j], dtype=np.float32) + + k = 0 + for k in range(self.pt_count): + if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: + break + + w[k] = 1. / ( + (i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) + + (j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1])) + + sw += w[k] + swp = swp + w[k] * np.array(self.dst_pts[k]) + swq = swq + w[k] * np.array(self.src_pts[k]) + + if k == self.pt_count - 1: + pstar = 1 / sw * swp + qstar = 1 / sw * swq + + miu_s = 0 + for k in range(self.pt_count): + if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: + continue + pt_i = self.dst_pts[k] - pstar + miu_s += w[k] * np.sum(pt_i * pt_i) + + cur_pt -= pstar + cur_pt_j = np.array([-cur_pt[1], cur_pt[0]]) + + for k in range(self.pt_count): + if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]: + continue + + pt_i = self.dst_pts[k] - pstar + pt_j = np.array([-pt_i[1], pt_i[0]]) + + tmp_pt = np.zeros(2, dtype=np.float32) + tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \ + np.sum(pt_j * cur_pt) * self.src_pts[k][1] + tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \ + np.sum(pt_j * cur_pt_j) * self.src_pts[k][1] + tmp_pt *= (w[k] / miu_s) + new_pt += tmp_pt + + new_pt += qstar + else: + new_pt = self.src_pts[k] + + self.rdx[j, i] = new_pt[0] - i + self.rdy[j, i] = new_pt[1] - j + + j += self.grid_size + i += self.grid_size + + def gen_img(self): + src_h, src_w = self.src.shape[:2] + dst = np.zeros_like(self.src, dtype=np.float32) + + for i in np.arange(0, self.dst_h, self.grid_size): + for j in np.arange(0, self.dst_w, self.grid_size): + ni = i + self.grid_size + nj = j + self.grid_size + w = h = self.grid_size + if ni >= self.dst_h: + ni = self.dst_h - 1 + h = ni - i + 1 + if nj >= self.dst_w: + nj = self.dst_w - 1 + w = nj - j + 1 + + di = np.reshape(np.arange(h), (-1, 1)) + dj = np.reshape(np.arange(w), (1, -1)) + delta_x = self.__bilinear_interp( + di / h, dj / w, self.rdx[i, j], self.rdx[i, nj], + self.rdx[ni, j], self.rdx[ni, nj]) + delta_y = self.__bilinear_interp( + di / h, dj / w, self.rdy[i, j], self.rdy[i, nj], + self.rdy[ni, j], self.rdy[ni, nj]) + nx = j + dj + delta_x * self.trans_ratio + ny = i + di + delta_y * self.trans_ratio + nx = np.clip(nx, 0, src_w - 1) + ny = np.clip(ny, 0, src_h - 1) + nxi = np.array(np.floor(nx), dtype=np.int32) + nyi = np.array(np.floor(ny), dtype=np.int32) + nxi1 = np.array(np.ceil(nx), dtype=np.int32) + nyi1 = np.array(np.ceil(ny), dtype=np.int32) + + if len(self.src.shape) == 3: + x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3)) + y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3)) + else: + x = ny - nyi + y = nx - nxi + dst[i:i + h, j:j + w] = self.__bilinear_interp( + x, y, self.src[nyi, nxi], self.src[nyi, nxi1], + self.src[nyi1, nxi], self.src[nyi1, nxi1]) + + dst = np.clip(dst, 0, 255) + dst = np.array(dst, dtype=np.uint8) + + return dst