improve style text infer process (#2055)
* improve style text * fix dead loop
This commit is contained in:
parent
6a42745f96
commit
d8719969ba
|
@ -38,7 +38,15 @@ class StyleTextRecPredictor(object):
|
||||||
self.std = config["Predictor"]["std"]
|
self.std = config["Predictor"]["std"]
|
||||||
self.expand_result = config["Predictor"]["expand_result"]
|
self.expand_result = config["Predictor"]["expand_result"]
|
||||||
|
|
||||||
def predict(self, style_input, text_input):
|
def reshape_to_same_height(self, img_list):
    """Resize every image in ``img_list`` to the height of the first one.

    The aspect ratio of each image is preserved: its new width is the
    original width scaled by the same factor as its height. The list is
    modified in place and also returned.

    Args:
        img_list: list of image arrays whose first two ``shape`` entries
            are (height, width).

    Returns:
        The same list object, with every entry sharing the height of
        ``img_list[0]``. An empty list is returned unchanged.
    """
    if not img_list:
        # Nothing to align; indexing img_list[0] below would raise.
        return img_list

    target_h = img_list[0].shape[0]
    for idx, img in enumerate(img_list):
        src_h, src_w = img.shape[:2]
        if src_h == target_h:
            # Already aligned (always true for the first image), so skip
            # a pointless resample.
            continue
        # Clamp to 1 px: a very narrow image can round down to width 0,
        # which cv2.resize rejects.
        new_w = max(1, round(1.0 * src_w / src_h * target_h))
        img_list[idx] = cv2.resize(img, (new_w, target_h))
    return img_list
|
||||||
|
|
||||||
|
def predict_single_image(self, style_input, text_input):
|
||||||
style_input = self.rep_style_input(style_input, text_input)
|
style_input = self.rep_style_input(style_input, text_input)
|
||||||
tensor_style_input = self.preprocess(style_input)
|
tensor_style_input = self.preprocess(style_input)
|
||||||
tensor_text_input = self.preprocess(text_input)
|
tensor_text_input = self.preprocess(text_input)
|
||||||
|
@ -64,6 +72,21 @@ class StyleTextRecPredictor(object):
|
||||||
"fake_bg": fake_bg,
|
"fake_bg": fake_bg,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def predict(self, style_input, text_input_list):
    """Synthesise one or several text images in the style of ``style_input``.

    Args:
        style_input: style image passed through to
            ``predict_single_image`` for every text image.
        text_input_list: either a single text image, or a list/tuple of
            text images that are synthesised one by one and stitched
            together side by side.

    Returns:
        For a single text image, the result of ``predict_single_image``.
        For a list/tuple, a new dict with the same keys in which each
        value is the per-image outputs brought to a common height and
        concatenated along axis 1.

    Raises:
        ValueError: if ``text_input_list`` is an empty list or tuple.
    """
    if not isinstance(text_input_list, (tuple, list)):
        return self.predict_single_image(style_input, text_input_list)

    if not text_input_list:
        # Fail with a clear message instead of an UnboundLocalError from
        # the merge loop below.
        raise ValueError(
            "text_input_list must contain at least one text image")

    synth_result_list = [
        self.predict_single_image(style_input, text_input)
        for text_input in text_input_list
    ]

    # Build a fresh dict so no per-image result is overwritten while it is
    # still being read.
    merged_result = {}
    for key in synth_result_list[-1]:
        pieces = [result[key] for result in synth_result_list]
        pieces = self.reshape_to_same_height(pieces)
        merged_result[key] = np.concatenate(pieces, axis=1)
    return merged_result
|
||||||
|
|
||||||
def preprocess(self, img):
|
def preprocess(self, img):
|
||||||
img = (img.astype('float32') * self.scale - self.mean) / self.std
|
img = (img.astype('float32') * self.scale - self.mean) / self.std
|
||||||
img_height, img_width, channel = img.shape
|
img_height, img_width, channel = img.shape
|
||||||
|
|
|
@ -12,6 +12,8 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import os
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
from utils.config import ArgsParser, load_config, override_config
|
from utils.config import ArgsParser, load_config, override_config
|
||||||
from utils.logging import get_logger
|
from utils.logging import get_logger
|
||||||
|
@ -36,8 +38,9 @@ class ImageSynthesiser(object):
|
||||||
self.predictor = getattr(predictors, predictor_method)(self.config)
|
self.predictor = getattr(predictors, predictor_method)(self.config)
|
||||||
|
|
||||||
def synth_image(self, corpus, style_input, language="en"):
    """Draw ``corpus`` as text image(s) and synthesise them on ``style_input``.

    Args:
        corpus: text to render.
        style_input: style image; its ``shape[1]`` is forwarded to the
            text drawer as ``style_input_width``.
        language: language key handed to the text drawer.

    Returns:
        Whatever ``self.predictor.predict`` returns for the style image
        and the drawn text images.
    """
    style_width = style_input.shape[1]
    # The text pieces returned by the drawer are not needed here; only
    # the rendered images go to the predictor.
    _, text_images = self.text_drawer.draw_text(
        corpus, language, style_input_width=style_width)
    return self.predictor.predict(style_input, text_images)
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,12 +62,15 @@ class DatasetSynthesiser(ImageSynthesiser):
|
||||||
for i in range(self.output_num):
|
for i in range(self.output_num):
|
||||||
style_data = self.style_sampler.sample()
|
style_data = self.style_sampler.sample()
|
||||||
style_input = style_data["image"]
|
style_input = style_data["image"]
|
||||||
corpus_language, text_input_label = self.corpus_generator.generate(
|
corpus_language, text_input_label = self.corpus_generator.generate()
|
||||||
)
|
text_input_label_list, text_input_list = self.text_drawer.draw_text(
|
||||||
text_input_label, text_input = self.text_drawer.draw_text(
|
text_input_label,
|
||||||
text_input_label, corpus_language)
|
corpus_language,
|
||||||
|
style_input_width=style_input.shape[1])
|
||||||
|
|
||||||
synth_result = self.predictor.predict(style_input, text_input)
|
text_input_label = "".join(text_input_label_list)
|
||||||
|
|
||||||
|
synth_result = self.predictor.predict(style_input, text_input_list)
|
||||||
fake_fusion = synth_result["fake_fusion"]
|
fake_fusion = synth_result["fake_fusion"]
|
||||||
self.writer.save_image(fake_fusion, text_input_label)
|
self.writer.save_image(fake_fusion, text_input_label)
|
||||||
self.writer.save_label()
|
self.writer.save_label()
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import cv2
|
||||||
from utils.logging import get_logger
|
from utils.logging import get_logger
|
||||||
|
|
||||||
|
|
||||||
|
@ -28,7 +29,11 @@ class StdTextDrawer(object):
|
||||||
else:
|
else:
|
||||||
return int((self.height - 4)**2 / font_height)
|
return int((self.height - 4)**2 / font_height)
|
||||||
|
|
||||||
def draw_text(self, corpus, language="en", crop=True):
|
def draw_text(self,
|
||||||
|
corpus,
|
||||||
|
language="en",
|
||||||
|
crop=True,
|
||||||
|
style_input_width=None):
|
||||||
if language not in self.support_languages:
|
if language not in self.support_languages:
|
||||||
self.logger.warning(
|
self.logger.warning(
|
||||||
"language {} not supported, use en instead.".format(language))
|
"language {} not supported, use en instead.".format(language))
|
||||||
|
@ -37,21 +42,43 @@ class StdTextDrawer(object):
|
||||||
width = min(self.max_width, len(corpus) * self.height) + 4
|
width = min(self.max_width, len(corpus) * self.height) + 4
|
||||||
else:
|
else:
|
||||||
width = len(corpus) * self.height + 4
|
width = len(corpus) * self.height + 4
|
||||||
bg = Image.new("RGB", (width, self.height), color=(127, 127, 127))
|
|
||||||
draw = ImageDraw.Draw(bg)
|
|
||||||
|
|
||||||
char_x = 2
|
if style_input_width is not None:
|
||||||
font = self.font_dict[language]
|
width = min(width, style_input_width)
|
||||||
for i, char_i in enumerate(corpus):
|
|
||||||
char_size = font.getsize(char_i)[0]
|
corpus_list = []
|
||||||
draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
|
text_input_list = []
|
||||||
char_x += char_size
|
|
||||||
if char_x >= width:
|
while len(corpus) != 0:
|
||||||
corpus = corpus[0:i + 1]
|
bg = Image.new("RGB", (width, self.height), color=(127, 127, 127))
|
||||||
self.logger.warning("corpus length exceed limit: {}".format(
|
draw = ImageDraw.Draw(bg)
|
||||||
corpus))
|
char_x = 2
|
||||||
|
font = self.font_dict[language]
|
||||||
|
i = 0
|
||||||
|
while i < len(corpus):
|
||||||
|
char_i = corpus[i]
|
||||||
|
char_size = font.getsize(char_i)[0]
|
||||||
|
# split when drawing the next char would reach the image width and index is not 0 (at least 1 char must be written on the image)
|
||||||
|
if char_x + char_size >= width and i != 0:
|
||||||
|
text_input = np.array(bg).astype(np.uint8)
|
||||||
|
text_input = text_input[:, 0:char_x, :]
|
||||||
|
|
||||||
|
corpus_list.append(corpus[0:i])
|
||||||
|
text_input_list.append(text_input)
|
||||||
|
corpus = corpus[i:]
|
||||||
|
break
|
||||||
|
draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
|
||||||
|
char_x += char_size
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
# the whole text is shorter than style input
|
||||||
|
if i == len(corpus):
|
||||||
|
text_input = np.array(bg).astype(np.uint8)
|
||||||
|
text_input = text_input[:, 0:char_x, :]
|
||||||
|
|
||||||
|
corpus_list.append(corpus[0:i])
|
||||||
|
text_input_list.append(text_input)
|
||||||
|
corpus = corpus[i:]
|
||||||
break
|
break
|
||||||
|
|
||||||
text_input = np.array(bg).astype(np.uint8)
|
return corpus_list, text_input_list
|
||||||
text_input = text_input[:, 0:char_x, :]
|
|
||||||
return corpus, text_input
|
|
||||||
|
|
Loading…
Reference in New Issue