commit
115b517584
|
@ -19,6 +19,8 @@ import random
|
|||
from ppocr.utils.utility import initial_logger
|
||||
logger = initial_logger()
|
||||
|
||||
from .text_image_aug.augment import tia_distort, tia_stretch, tia_perspective
|
||||
|
||||
|
||||
def get_bounding_box_rect(pos):
|
||||
left = min(pos[0])
|
||||
|
@ -196,6 +198,9 @@ class Config:
|
|||
self.h = h
|
||||
|
||||
self.perspective = True
|
||||
self.stretch = True
|
||||
self.distort = True
|
||||
|
||||
self.crop = True
|
||||
self.affine = False
|
||||
self.reverse = True
|
||||
|
@ -299,41 +304,40 @@ def warp(img, ang):
|
|||
config.make(w, h, ang)
|
||||
new_img = img
|
||||
|
||||
prob = 0.4
|
||||
|
||||
if config.distort:
|
||||
img_height, img_width = img.shape[0:2]
|
||||
if random.random() <= prob and img_height >= 20 and img_width >= 20:
|
||||
new_img = tia_distort(new_img, random.randint(3, 6))
|
||||
|
||||
if config.stretch:
|
||||
img_height, img_width = img.shape[0:2]
|
||||
if random.random() <= prob and img_height >= 20 and img_width >= 20:
|
||||
new_img = tia_stretch(new_img, random.randint(3, 6))
|
||||
|
||||
if config.perspective:
|
||||
tp = random.randint(1, 100)
|
||||
if tp >= 50:
|
||||
warpR, (r1, c1), ratio, dst = get_warpR(config)
|
||||
new_w = int(np.max(dst[:, 0])) - int(np.min(dst[:, 0]))
|
||||
new_img = cv2.warpPerspective(
|
||||
new_img,
|
||||
warpR, (int(new_w * ratio), h),
|
||||
borderMode=config.borderMode)
|
||||
if random.random() <= prob:
|
||||
new_img = tia_perspective(new_img)
|
||||
|
||||
if config.crop:
|
||||
img_height, img_width = img.shape[0:2]
|
||||
tp = random.randint(1, 100)
|
||||
if tp >= 50 and img_height >= 20 and img_width >= 20:
|
||||
if random.random() <= prob and img_height >= 20 and img_width >= 20:
|
||||
new_img = get_crop(new_img)
|
||||
if config.affine:
|
||||
warpT = get_warpAffine(config)
|
||||
new_img = cv2.warpAffine(
|
||||
new_img, warpT, (w, h), borderMode=config.borderMode)
|
||||
|
||||
if config.blur:
|
||||
tp = random.randint(1, 100)
|
||||
if tp >= 50:
|
||||
if random.random() <= prob:
|
||||
new_img = blur(new_img)
|
||||
if config.color:
|
||||
tp = random.randint(1, 100)
|
||||
if tp >= 50:
|
||||
if random.random() <= prob:
|
||||
new_img = cvtColor(new_img)
|
||||
if config.jitter:
|
||||
new_img = jitter(new_img)
|
||||
if config.noise:
|
||||
tp = random.randint(1, 100)
|
||||
if tp >= 50:
|
||||
if random.random() <= prob:
|
||||
new_img = add_gasuss_noise(new_img)
|
||||
if config.reverse:
|
||||
tp = random.randint(1, 100)
|
||||
if tp >= 50:
|
||||
if random.random() <= prob:
|
||||
new_img = 255 - new_img
|
||||
return new_img
|
||||
|
||||
|
@ -382,6 +386,7 @@ def process_image(img,
|
|||
% loss_type
|
||||
return (norm_img)
|
||||
|
||||
|
||||
def resize_norm_img_srn(img, image_shape):
|
||||
imgC, imgH, imgW = image_shape
|
||||
|
||||
|
@ -408,30 +413,39 @@ def resize_norm_img_srn(img, image_shape):
|
|||
|
||||
return np.reshape(img_black, (c, row, col)).astype(np.float32)
|
||||
|
||||
def srn_other_inputs(image_shape,
|
||||
num_heads,
|
||||
max_text_length,
|
||||
char_num):
|
||||
|
||||
def srn_other_inputs(image_shape, num_heads, max_text_length, char_num):
|
||||
|
||||
imgC, imgH, imgW = image_shape
|
||||
feature_dim = int((imgH / 8) * (imgW / 8))
|
||||
|
||||
encoder_word_pos = np.array(range(0, feature_dim)).reshape((feature_dim, 1)).astype('int64')
|
||||
gsrm_word_pos = np.array(range(0, max_text_length)).reshape((max_text_length, 1)).astype('int64')
|
||||
encoder_word_pos = np.array(range(0, feature_dim)).reshape(
|
||||
(feature_dim, 1)).astype('int64')
|
||||
gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
|
||||
(max_text_length, 1)).astype('int64')
|
||||
|
||||
lbl_weight = np.array([int(char_num-1)] * max_text_length).reshape((-1,1)).astype('int64')
|
||||
lbl_weight = np.array([int(char_num - 1)] * max_text_length).reshape(
|
||||
(-1, 1)).astype('int64')
|
||||
|
||||
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape([-1, 1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, [1, num_heads, 1, 1]) * [-1e9]
|
||||
gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
|
||||
gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
|
||||
[-1, 1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
|
||||
[1, num_heads, 1, 1]) * [-1e9]
|
||||
|
||||
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape([-1, 1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, [1, num_heads, 1, 1]) * [-1e9]
|
||||
gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
|
||||
[-1, 1, max_text_length, max_text_length])
|
||||
gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
|
||||
[1, num_heads, 1, 1]) * [-1e9]
|
||||
|
||||
encoder_word_pos = encoder_word_pos[np.newaxis, :]
|
||||
gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
|
||||
|
||||
return [lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2]
|
||||
return [
|
||||
lbl_weight, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2
|
||||
]
|
||||
|
||||
|
||||
def process_image_srn(img,
|
||||
image_shape,
|
||||
|
@ -453,14 +467,16 @@ def process_image_srn(img,
|
|||
return None
|
||||
else:
|
||||
if loss_type == "srn":
|
||||
text_padded = [int(char_num-1)] * max_text_length
|
||||
text_padded = [int(char_num - 1)] * max_text_length
|
||||
for i in range(len(text)):
|
||||
text_padded[i] = text[i]
|
||||
lbl_weight[i] = [1.0]
|
||||
text_padded = np.array(text_padded)
|
||||
text = text_padded.reshape(-1, 1)
|
||||
return (norm_img, text,encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2,lbl_weight)
|
||||
return (norm_img, text, encoder_word_pos, gsrm_word_pos,
|
||||
gsrm_slf_attn_bias1, gsrm_slf_attn_bias2, lbl_weight)
|
||||
else:
|
||||
assert False, "Unsupport loss_type %s in process_image"\
|
||||
% loss_type
|
||||
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2)
|
||||
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
|
||||
gsrm_slf_attn_bias2)
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
# -*- coding:utf-8 -*-
|
||||
# Author: RubanSeven
|
||||
# Reference: https://github.com/RubanSeven/Text-Image-Augmentation-python
|
||||
|
||||
# import cv2
|
||||
import numpy as np
|
||||
from .warp_mls import WarpMLS
|
||||
|
||||
|
||||
def tia_distort(src, segment=4):
|
||||
img_h, img_w = src.shape[:2]
|
||||
|
||||
cut = img_w // segment
|
||||
thresh = cut // 3
|
||||
|
||||
src_pts = list()
|
||||
dst_pts = list()
|
||||
|
||||
src_pts.append([0, 0])
|
||||
src_pts.append([img_w, 0])
|
||||
src_pts.append([img_w, img_h])
|
||||
src_pts.append([0, img_h])
|
||||
|
||||
dst_pts.append([np.random.randint(thresh), np.random.randint(thresh)])
|
||||
dst_pts.append(
|
||||
[img_w - np.random.randint(thresh), np.random.randint(thresh)])
|
||||
dst_pts.append(
|
||||
[img_w - np.random.randint(thresh), img_h - np.random.randint(thresh)])
|
||||
dst_pts.append(
|
||||
[np.random.randint(thresh), img_h - np.random.randint(thresh)])
|
||||
|
||||
half_thresh = thresh * 0.5
|
||||
|
||||
for cut_idx in np.arange(1, segment, 1):
|
||||
src_pts.append([cut * cut_idx, 0])
|
||||
src_pts.append([cut * cut_idx, img_h])
|
||||
dst_pts.append([
|
||||
cut * cut_idx + np.random.randint(thresh) - half_thresh,
|
||||
np.random.randint(thresh) - half_thresh
|
||||
])
|
||||
dst_pts.append([
|
||||
cut * cut_idx + np.random.randint(thresh) - half_thresh,
|
||||
img_h + np.random.randint(thresh) - half_thresh
|
||||
])
|
||||
|
||||
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
|
||||
dst = trans.generate()
|
||||
|
||||
return dst
|
||||
|
||||
|
||||
def tia_stretch(src, segment=4):
|
||||
img_h, img_w = src.shape[:2]
|
||||
|
||||
cut = img_w // segment
|
||||
thresh = cut * 4 // 5
|
||||
|
||||
src_pts = list()
|
||||
dst_pts = list()
|
||||
|
||||
src_pts.append([0, 0])
|
||||
src_pts.append([img_w, 0])
|
||||
src_pts.append([img_w, img_h])
|
||||
src_pts.append([0, img_h])
|
||||
|
||||
dst_pts.append([0, 0])
|
||||
dst_pts.append([img_w, 0])
|
||||
dst_pts.append([img_w, img_h])
|
||||
dst_pts.append([0, img_h])
|
||||
|
||||
half_thresh = thresh * 0.5
|
||||
|
||||
for cut_idx in np.arange(1, segment, 1):
|
||||
move = np.random.randint(thresh) - half_thresh
|
||||
src_pts.append([cut * cut_idx, 0])
|
||||
src_pts.append([cut * cut_idx, img_h])
|
||||
dst_pts.append([cut * cut_idx + move, 0])
|
||||
dst_pts.append([cut * cut_idx + move, img_h])
|
||||
|
||||
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
|
||||
dst = trans.generate()
|
||||
|
||||
return dst
|
||||
|
||||
|
||||
def tia_perspective(src):
|
||||
img_h, img_w = src.shape[:2]
|
||||
|
||||
thresh = img_h // 2
|
||||
|
||||
src_pts = list()
|
||||
dst_pts = list()
|
||||
|
||||
src_pts.append([0, 0])
|
||||
src_pts.append([img_w, 0])
|
||||
src_pts.append([img_w, img_h])
|
||||
src_pts.append([0, img_h])
|
||||
|
||||
dst_pts.append([0, np.random.randint(thresh)])
|
||||
dst_pts.append([img_w, np.random.randint(thresh)])
|
||||
dst_pts.append([img_w, img_h - np.random.randint(thresh)])
|
||||
dst_pts.append([0, img_h - np.random.randint(thresh)])
|
||||
|
||||
trans = WarpMLS(src, src_pts, dst_pts, img_w, img_h)
|
||||
dst = trans.generate()
|
||||
|
||||
return dst
|
|
@ -0,0 +1,154 @@
|
|||
# -*- coding:utf-8 -*-
|
||||
# Author: RubanSeven
|
||||
# Reference: https://github.com/RubanSeven/Text-Image-Augmentation-python
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
class WarpMLS:
|
||||
def __init__(self, src, src_pts, dst_pts, dst_w, dst_h, trans_ratio=1.):
|
||||
self.src = src
|
||||
self.src_pts = src_pts
|
||||
self.dst_pts = dst_pts
|
||||
self.pt_count = len(self.dst_pts)
|
||||
self.dst_w = dst_w
|
||||
self.dst_h = dst_h
|
||||
self.trans_ratio = trans_ratio
|
||||
self.grid_size = 100
|
||||
self.rdx = np.zeros((self.dst_h, self.dst_w))
|
||||
self.rdy = np.zeros((self.dst_h, self.dst_w))
|
||||
|
||||
@staticmethod
|
||||
def __bilinear_interp(x, y, v11, v12, v21, v22):
|
||||
return (v11 * (1 - y) + v12 * y) * (1 - x) + (v21 *
|
||||
(1 - y) + v22 * y) * x
|
||||
|
||||
def generate(self):
|
||||
self.calc_delta()
|
||||
return self.gen_img()
|
||||
|
||||
def calc_delta(self):
|
||||
w = np.zeros(self.pt_count, dtype=np.float32)
|
||||
|
||||
if self.pt_count < 2:
|
||||
return
|
||||
|
||||
i = 0
|
||||
while 1:
|
||||
if self.dst_w <= i < self.dst_w + self.grid_size - 1:
|
||||
i = self.dst_w - 1
|
||||
elif i >= self.dst_w:
|
||||
break
|
||||
|
||||
j = 0
|
||||
while 1:
|
||||
if self.dst_h <= j < self.dst_h + self.grid_size - 1:
|
||||
j = self.dst_h - 1
|
||||
elif j >= self.dst_h:
|
||||
break
|
||||
|
||||
sw = 0
|
||||
swp = np.zeros(2, dtype=np.float32)
|
||||
swq = np.zeros(2, dtype=np.float32)
|
||||
new_pt = np.zeros(2, dtype=np.float32)
|
||||
cur_pt = np.array([i, j], dtype=np.float32)
|
||||
|
||||
k = 0
|
||||
for k in range(self.pt_count):
|
||||
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
|
||||
break
|
||||
|
||||
w[k] = 1. / (
|
||||
(i - self.dst_pts[k][0]) * (i - self.dst_pts[k][0]) +
|
||||
(j - self.dst_pts[k][1]) * (j - self.dst_pts[k][1]))
|
||||
|
||||
sw += w[k]
|
||||
swp = swp + w[k] * np.array(self.dst_pts[k])
|
||||
swq = swq + w[k] * np.array(self.src_pts[k])
|
||||
|
||||
if k == self.pt_count - 1:
|
||||
pstar = 1 / sw * swp
|
||||
qstar = 1 / sw * swq
|
||||
|
||||
miu_s = 0
|
||||
for k in range(self.pt_count):
|
||||
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
|
||||
continue
|
||||
pt_i = self.dst_pts[k] - pstar
|
||||
miu_s += w[k] * np.sum(pt_i * pt_i)
|
||||
|
||||
cur_pt -= pstar
|
||||
cur_pt_j = np.array([-cur_pt[1], cur_pt[0]])
|
||||
|
||||
for k in range(self.pt_count):
|
||||
if i == self.dst_pts[k][0] and j == self.dst_pts[k][1]:
|
||||
continue
|
||||
|
||||
pt_i = self.dst_pts[k] - pstar
|
||||
pt_j = np.array([-pt_i[1], pt_i[0]])
|
||||
|
||||
tmp_pt = np.zeros(2, dtype=np.float32)
|
||||
tmp_pt[0] = np.sum(pt_i * cur_pt) * self.src_pts[k][0] - \
|
||||
np.sum(pt_j * cur_pt) * self.src_pts[k][1]
|
||||
tmp_pt[1] = -np.sum(pt_i * cur_pt_j) * self.src_pts[k][0] + \
|
||||
np.sum(pt_j * cur_pt_j) * self.src_pts[k][1]
|
||||
tmp_pt *= (w[k] / miu_s)
|
||||
new_pt += tmp_pt
|
||||
|
||||
new_pt += qstar
|
||||
else:
|
||||
new_pt = self.src_pts[k]
|
||||
|
||||
self.rdx[j, i] = new_pt[0] - i
|
||||
self.rdy[j, i] = new_pt[1] - j
|
||||
|
||||
j += self.grid_size
|
||||
i += self.grid_size
|
||||
|
||||
def gen_img(self):
|
||||
src_h, src_w = self.src.shape[:2]
|
||||
dst = np.zeros_like(self.src, dtype=np.float32)
|
||||
|
||||
for i in np.arange(0, self.dst_h, self.grid_size):
|
||||
for j in np.arange(0, self.dst_w, self.grid_size):
|
||||
ni = i + self.grid_size
|
||||
nj = j + self.grid_size
|
||||
w = h = self.grid_size
|
||||
if ni >= self.dst_h:
|
||||
ni = self.dst_h - 1
|
||||
h = ni - i + 1
|
||||
if nj >= self.dst_w:
|
||||
nj = self.dst_w - 1
|
||||
w = nj - j + 1
|
||||
|
||||
di = np.reshape(np.arange(h), (-1, 1))
|
||||
dj = np.reshape(np.arange(w), (1, -1))
|
||||
delta_x = self.__bilinear_interp(
|
||||
di / h, dj / w, self.rdx[i, j], self.rdx[i, nj],
|
||||
self.rdx[ni, j], self.rdx[ni, nj])
|
||||
delta_y = self.__bilinear_interp(
|
||||
di / h, dj / w, self.rdy[i, j], self.rdy[i, nj],
|
||||
self.rdy[ni, j], self.rdy[ni, nj])
|
||||
nx = j + dj + delta_x * self.trans_ratio
|
||||
ny = i + di + delta_y * self.trans_ratio
|
||||
nx = np.clip(nx, 0, src_w - 1)
|
||||
ny = np.clip(ny, 0, src_h - 1)
|
||||
nxi = np.array(np.floor(nx), dtype=np.int32)
|
||||
nyi = np.array(np.floor(ny), dtype=np.int32)
|
||||
nxi1 = np.array(np.ceil(nx), dtype=np.int32)
|
||||
nyi1 = np.array(np.ceil(ny), dtype=np.int32)
|
||||
|
||||
if len(self.src.shape) == 3:
|
||||
x = np.tile(np.expand_dims(ny - nyi, axis=-1), (1, 1, 3))
|
||||
y = np.tile(np.expand_dims(nx - nxi, axis=-1), (1, 1, 3))
|
||||
else:
|
||||
x = ny - nyi
|
||||
y = nx - nxi
|
||||
dst[i:i + h, j:j + w] = self.__bilinear_interp(
|
||||
x, y, self.src[nyi, nxi], self.src[nyi, nxi1],
|
||||
self.src[nyi1, nxi], self.src[nyi1, nxi1])
|
||||
|
||||
dst = np.clip(dst, 0, 255)
|
||||
dst = np.array(dst, dtype=np.uint8)
|
||||
|
||||
return dst
|
Loading…
Reference in New Issue