diff --git a/doc/doc_ch/style_text_rec.md b/doc/doc_ch/style_text_rec.md
new file mode 100644
index 00000000..3a6d8b8b
--- /dev/null
+++ b/doc/doc_ch/style_text_rec.md
@@ -0,0 +1,150 @@
+## Style Text Rec
+
+### Contents
+[Tool Introduction](#tool-introduction)
+[Environment Setup](#environment-setup)
+[Quick Start](#quick-start)
+[Advanced Usage](#advanced-usage)
+[Application Examples](#application-examples)
+
+### Tool Introduction
+
+Style-Text is an improvement on the SRNet network proposed in Baidu's self-developed text editing algorithm "Editing Text in the Wild". Unlike common GAN-based methods that use only a single branch, this tool decomposes the text synthesis task into three sub-modules (a text style transfer module, a background extraction module and a foreground-background fusion module) to improve the quality of the synthetic data. Some example results can be found in `doc/imgs_style_text`.
+
+In addition, the effectiveness of this synthesis tool has been verified in real nameplate text recognition and Korean text recognition scenarios.
+
+### Environment Setup
+
+1. Refer to [Quick Installation](./installation.md) to install PaddleOCR. A Python 3 environment is strongly recommended.
+2. Enter the `style_text_rec` directory, then download and unzip the models:
+
+```bash
+cd style_text_rec
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip
+unzip style_text_models.zip
+```
+
+If you save the models in another location, edit the model paths in `configs/config.yml`. All three of the following entries must be updated together:
+
+```yaml
+bg_generator:
+ pretrain: style_text_models/bg_generator
+...
+text_generator:
+ pretrain: style_text_models/text_generator
+...
+fusion_generator:
+ pretrain: style_text_models/fusion_generator
+```
+
+### Quick Start
+
+1. Run `tools/synth_image` to generate an example image:
+
+```bash
+python3 -m tools.synth_image -c configs/config.yml
+```
+
+2. After running, `fake_fusion.jpg` will be generated as the final result.
+
+In addition, the program also generates and saves intermediate results:
+ * `fake_bg.jpg`: the background of the style reference image with the text removed;
+ * `fake_text.jpg`: a text image of the provided string, rendered on a gray background in a style imitating the text of the style reference image.
+
+3. If you want to try other style images and texts, add the `style_image`, `text_corpus` and `language` parameters:
+```bash
+python3 -m tools.synth_image -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en
+```
+ * Note: the language option must match the corpus; English (`en`), Simplified Chinese (`ch`) and Korean (`ko`) are currently supported.
+
+4. `tools/synth_image.py` also provides a `batch_synth_images` method, which pairs corpora with images to generate a batch of data.
+
+### Advanced Usage
+
+Before synthesizing a dataset, some materials need to be prepared.
+
+First, style images are needed as references for synthesis. They can come from datasets used to train OCR recognition models. This example uses a dataset with a label file as the source of style images.
+
+1. Configure the input data paths in `configs/dataset_config.yml` (an excerpt of the default file is shown after this list).
+   * `StyleSampler`:
+     * `method`: the sampling method used for style images;
+     * `image_home`: the directory of the style images;
+     * `label_file`: a file listing the style image paths; if the dataset is labeled, `label_file` is the path of the label file;
+     * `with_label`: whether `label_file` is a label file.
+   We provide a set of [sample images](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/chkoen_5w.tar) for you to try.
+   * `CorpusGenerator`:
+     * `method`: the corpus generation method; `FileCorpus` and `EnNumCorpus` are currently available. With `EnNumCorpus`, no other settings are needed; otherwise, `corpus_file` and `language` must be set;
+     * `language`: the language of the corpus;
+     * `corpus_file`: the path of the corpus file.
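+
+   For reference, the corresponding section of the default `configs/dataset_config.yml` shipped with this tool is:
+
+   ```yaml
+   StyleSampler:
+     method: DatasetSampler
+     image_home: examples
+     label_file: examples/image_list.txt
+     with_label: true
+   CorpusGenerator:
+     method: FileCorpus
+     language: ch
+     corpus_file: examples/corpus/example.txt
+   ```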
+
+2. Run `tools/synth_dataset` to synthesize data:
+
+   ```bash
+   python3 -m tools.synth_dataset -c configs/dataset_config.yml
+   ```
+
+3. To synthesize data faster in parallel, start multiple processes, each with a different `tag` (`-t`), as shown below:
+
+ ```bash
+ python3 -m tools.synth_dataset -t 0 -c configs/dataset_config.yml
+ python3 -m tools.synth_dataset -t 1 -c configs/dataset_config.yml
+ ```
+
+
+### Application Examples
+
+After completing the above steps, you obtain a synthetic dataset for OCR recognition.
+
+Next, refer to the [OCR recognition documentation](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/recognition.md#%E5%90%AF%E5%8A%A8%E8%AE%AD%E7%BB%83) to complete the training.
+
+### Project Structure
+```
+style_text_rec
+|-- arch
+| |-- base_module.py
+| |-- decoder.py
+| |-- encoder.py
+| |-- spectral_norm.py
+| `-- style_text_rec.py
+|-- configs
+| |-- config.yml
+| `-- dataset_config.yml
+|-- engine
+| |-- corpus_generators.py
+| |-- predictors.py
+| |-- style_samplers.py
+| |-- synthesisers.py
+| |-- text_drawers.py
+| `-- writers.py
+|-- examples
+| |-- corpus
+| | `-- example.txt
+| |-- image_list.txt
+| `-- style_images
+| |-- 1.jpg
+| `-- 2.jpg
+|-- fonts
+| |-- ch_standard.ttf
+| |-- en_standard.ttf
+| `-- ko_standard.ttf
+|-- tools
+| |-- __init__.py
+| |-- synth_dataset.py
+| `-- synth_image.py
+`-- utils
+ |-- config.py
+ |-- load_params.py
+ |-- logging.py
+ |-- math_functions.py
+ `-- sys_funcs.py
+```
\ No newline at end of file
diff --git a/doc/doc_en/style_text_rec_en.md b/doc/doc_en/style_text_rec_en.md
new file mode 100644
index 00000000..7e7d29c9
--- /dev/null
+++ b/doc/doc_en/style_text_rec_en.md
@@ -0,0 +1,106 @@
+## Style Text Rec
+
+### Quick Start
+
+`Style-Text` is an improvement on the SRNet network proposed in Baidu's self-developed text editing algorithm "Editing Text in the Wild". Unlike common GAN-based methods that use only a single branch, this tool decomposes the text synthesis task into three sub-modules to improve the quality of the synthetic data: a text style transfer module, a background extraction module and a fusion module.
+
+Some example results can be found in `doc/imgs_style_text`. In addition, the effectiveness of the synthesis tool has been verified in real nameplate text recognition and Korean text recognition scenarios.
+
+
+#### Preparation
+
+1. Please refer to the [QUICK INSTALLATION](./installation_en.md) to install PaddleOCR. A Python 3 environment is strongly recommended.
+2. Download the pretrained models and unzip them:
+
+```bash
+cd style_text_rec
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip
+unzip style_text_models.zip
+```
+
+You can download the models [here](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip). If you save the model files in other folders, please edit the three model paths in `configs/config.yml`:
+
+```yaml
+bg_generator:
+  pretrain: style_text_models/bg_generator
+...
+text_generator:
+ pretrain: style_text_models/text_generator
+...
+fusion_generator:
+ pretrain: style_text_models/fusion_generator
+```
+
+
+
+#### Demo
+
+1. You can run a demo with the following command:
+
+```bash
+python3 -m tools.synth_image -c configs/config.yml
+```
+
+2. The results are `fake_bg.jpg`, `fake_text.jpg` and `fake_fusion.jpg`. Among them:
+ * `fake_text.jpg` is the generated image with the same font style as the `Style Input`;
+ * `fake_bg.jpg` is the generated image of the `Style Input` with the foreground removed;
+ * `fake_fusion.jpg` is the final result, synthesized from `fake_text.jpg` and `fake_bg.jpg`.
+
+3. If you want to generate images with other `Style Input` or `Text Input` values, modify the corresponding lines in `tools/synth_image.py`:
+ * `img = cv2.imread("examples/style_images/1.jpg")`: the path of the `Style Input`;
+ * `corpus = "PaddleOCR"`: the `Text Input`;
+ * Note: the language option (`language = "en"`) must match the `Text Input`; `en`, `ch` and `ko` are supported.
+
+4. We also provide a `batch_synth_images` method, which pairs corpora with pictures to generate a batch of data.
+
+### Advanced Usage
+
+#### Components
+
+`Style Text Rec` mainly contains the following components:
+
+* `style_samplers`: samples `Style Input` images from a dataset. Currently, only `DatasetSampler` is provided.
+
+* `corpus_generators`: generates corpora. Currently, two `corpus_generators` are provided:
+  * `EnNumCorpus`: generates a random string of a given length, containing uppercase and lowercase English letters, numbers and spaces.
+  * `FileCorpus`: reads a text file and randomly returns words from it.
+
+* `text_drawers`: generates the `Text Input` (a text image rendered in a standard font from the input corpus). Note that the language information must be set to match the corpus.
+
+* `predictors`: calls the deep learning model to generate new data from the `Style Input` and `Text Input`.
+
+* `writers`: writes the generated images (`fake_bg.jpg`, `fake_text.jpg` and `fake_fusion.jpg`) and the label information to disk.
+
+* `synthesisers`: calls all of the modules above to complete the whole pipeline, as shown in the sketch below.
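+
+As a minimal usage sketch of how these components fit together, the snippet below drives `ImageSynthesiser` from `engine/synthesisers.py` directly. It assumes you save it as a hypothetical `demo.py` inside the `style_text_rec` directory and, since `ImageSynthesiser` reads its config path from the command line via `ArgsParser`, that you invoke it with `-c configs/config.yml`:
+
+```python
+import cv2
+
+from engine.synthesisers import ImageSynthesiser
+
+# Run as: python3 demo.py -c configs/config.yml
+# (ImageSynthesiser parses the config path from the command line.)
+synthesiser = ImageSynthesiser()
+style_input = cv2.imread("examples/style_images/1.jpg")  # Style Input
+result = synthesiser.synth_image("PaddleOCR", style_input, language="en")
+# The result dict also contains fake_text, fake_sk and fake_bg.
+cv2.imwrite("fake_fusion.jpg", result["fake_fusion"])
+```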
+
+### Generate Dataset
+
+Before starting, you need to prepare some materials.
+First, you need style reference images for the synthesis task; these are typically taken from datasets used for OCR recognition tasks.
+
+1. The reference dataset can be specified in `configs/dataset_config.yml`:
+   * `StyleSampler`:
+     * `method`: the method of `StyleSampler`.
+     * `image_home`: the directory of the images.
+     * `label_file`: the list of image paths if `with_label` is `false`; otherwise, the path of the label file.
+     * `with_label`: whether `label_file` is a label file.
+
+   * `CorpusGenerator`:
+     * `method`: the method of `CorpusGenerator`. If `FileCorpus` is used, `corpus_file` and `language` must be set accordingly; with `EnNumCorpus`, no other settings are needed.
+     * `language`: the language of the corpus. Needed unless the method is `EnNumCorpus`.
+     * `corpus_file`: the corpus file path. Needed unless the method is `EnNumCorpus`.
+
+2. You can run the following command to start the synthesis task:
+
+   ```bash
+   python3 -m tools.synth_dataset -c configs/dataset_config.yml
+   ```
+
+3. You can start multiple synthesis tasks in parallel by launching several processes, each with a different tag specified via `-t`:
+
+   ```bash
+   python3 -m tools.synth_dataset -t 0 -c configs/dataset_config.yml
+   python3 -m tools.synth_dataset -t 1 -c configs/dataset_config.yml
+   ```
+
+### OCR Recognition Training
+
+After completing the above operations, you can get a synthetic dataset for OCR recognition. Next, please complete the training by referring to the [OCR Recognition Document](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/recognition.md#%E5%90%AF%E5%8A%A8%E8%AE%AD%E7%BB%83).
\ No newline at end of file
diff --git a/doc/imgs_style_text/1.png b/doc/imgs_style_text/1.png
new file mode 100644
index 00000000..8f7574ba
Binary files /dev/null and b/doc/imgs_style_text/1.png differ
diff --git a/doc/imgs_style_text/2.png b/doc/imgs_style_text/2.png
new file mode 100644
index 00000000..ce9bf471
Binary files /dev/null and b/doc/imgs_style_text/2.png differ
diff --git a/doc/imgs_style_text/3.png b/doc/imgs_style_text/3.png
new file mode 100644
index 00000000..0fb73a31
Binary files /dev/null and b/doc/imgs_style_text/3.png differ
diff --git a/doc/imgs_style_text/4.jpg b/doc/imgs_style_text/4.jpg
new file mode 100644
index 00000000..5fda9548
Binary files /dev/null and b/doc/imgs_style_text/4.jpg differ
diff --git a/doc/imgs_style_text/5.png b/doc/imgs_style_text/5.png
new file mode 100644
index 00000000..ea0b8903
Binary files /dev/null and b/doc/imgs_style_text/5.png differ
diff --git a/style_text_rec/__init__.py b/style_text_rec/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/style_text_rec/arch/__init__.py b/style_text_rec/arch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/style_text_rec/arch/base_module.py b/style_text_rec/arch/base_module.py
new file mode 100644
index 00000000..da2b6b83
--- /dev/null
+++ b/style_text_rec/arch/base_module.py
@@ -0,0 +1,255 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+
+from arch.spectral_norm import spectral_norm
+
+
+class CBN(nn.Layer):
+ def __init__(self,
+ name,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ use_bias=False,
+ norm_layer=None,
+ act=None,
+ act_attr=None):
+ super(CBN, self).__init__()
+ if use_bias:
+ bias_attr = paddle.ParamAttr(name=name + "_bias")
+ else:
+ bias_attr = None
+ self._conv = paddle.nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ weight_attr=paddle.ParamAttr(name=name + "_weights"),
+ bias_attr=bias_attr)
+ if norm_layer:
+ self._norm_layer = getattr(paddle.nn, norm_layer)(
+ num_features=out_channels, name=name + "_bn")
+ else:
+ self._norm_layer = None
+ if act:
+ if act_attr:
+ self._act = getattr(paddle.nn, act)(**act_attr,
+ name=name + "_" + act)
+ else:
+ self._act = getattr(paddle.nn, act)(name=name + "_" + act)
+ else:
+ self._act = None
+
+ def forward(self, x):
+ out = self._conv(x)
+ if self._norm_layer:
+ out = self._norm_layer(out)
+ if self._act:
+ out = self._act(out)
+ return out
+
+
+class SNConv(nn.Layer):
+ def __init__(self,
+ name,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ use_bias=False,
+ norm_layer=None,
+ act=None,
+ act_attr=None):
+ super(SNConv, self).__init__()
+ if use_bias:
+ bias_attr = paddle.ParamAttr(name=name + "_bias")
+ else:
+ bias_attr = None
+ self._sn_conv = spectral_norm(
+ paddle.nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ weight_attr=paddle.ParamAttr(name=name + "_weights"),
+ bias_attr=bias_attr))
+ if norm_layer:
+ self._norm_layer = getattr(paddle.nn, norm_layer)(
+ num_features=out_channels, name=name + "_bn")
+ else:
+ self._norm_layer = None
+ if act:
+ if act_attr:
+ self._act = getattr(paddle.nn, act)(**act_attr,
+ name=name + "_" + act)
+ else:
+ self._act = getattr(paddle.nn, act)(name=name + "_" + act)
+ else:
+ self._act = None
+
+ def forward(self, x):
+ out = self._sn_conv(x)
+ if self._norm_layer:
+ out = self._norm_layer(out)
+ if self._act:
+ out = self._act(out)
+ return out
+
+
+class SNConvTranspose(nn.Layer):
+ def __init__(self,
+ name,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ output_padding=0,
+ dilation=1,
+ groups=1,
+ use_bias=False,
+ norm_layer=None,
+ act=None,
+ act_attr=None):
+ super(SNConvTranspose, self).__init__()
+ if use_bias:
+ bias_attr = paddle.ParamAttr(name=name + "_bias")
+ else:
+ bias_attr = None
+ self._sn_conv_transpose = spectral_norm(
+ paddle.nn.Conv2DTranspose(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ dilation=dilation,
+ groups=groups,
+ weight_attr=paddle.ParamAttr(name=name + "_weights"),
+ bias_attr=bias_attr))
+ if norm_layer:
+ self._norm_layer = getattr(paddle.nn, norm_layer)(
+ num_features=out_channels, name=name + "_bn")
+ else:
+ self._norm_layer = None
+ if act:
+ if act_attr:
+ self._act = getattr(paddle.nn, act)(**act_attr,
+ name=name + "_" + act)
+ else:
+ self._act = getattr(paddle.nn, act)(name=name + "_" + act)
+ else:
+ self._act = None
+
+ def forward(self, x):
+ out = self._sn_conv_transpose(x)
+ if self._norm_layer:
+ out = self._norm_layer(out)
+ if self._act:
+ out = self._act(out)
+ return out
+
+
+class MiddleNet(nn.Layer):
+ def __init__(self, name, in_channels, mid_channels, out_channels,
+ use_bias):
+ super(MiddleNet, self).__init__()
+ self._sn_conv1 = SNConv(
+ name=name + "_sn_conv1",
+ in_channels=in_channels,
+ out_channels=mid_channels,
+ kernel_size=1,
+ use_bias=use_bias,
+ norm_layer=None,
+ act=None)
+ self._pad2d = nn.Pad2D(padding=[1, 1, 1, 1], mode="replicate")
+ self._sn_conv2 = SNConv(
+ name=name + "_sn_conv2",
+ in_channels=mid_channels,
+ out_channels=mid_channels,
+ kernel_size=3,
+ use_bias=use_bias)
+ self._sn_conv3 = SNConv(
+ name=name + "_sn_conv3",
+ in_channels=mid_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ use_bias=use_bias)
+
+ def forward(self, x):
+
+ sn_conv1 = self._sn_conv1.forward(x)
+ pad_2d = self._pad2d.forward(sn_conv1)
+ sn_conv2 = self._sn_conv2.forward(pad_2d)
+ sn_conv3 = self._sn_conv3.forward(sn_conv2)
+ return sn_conv3
+
+
+class ResBlock(nn.Layer):
+ def __init__(self, name, channels, norm_layer, use_dropout, use_dilation,
+ use_bias):
+ super(ResBlock, self).__init__()
+ if use_dilation:
+ padding_mat = [1, 1, 1, 1]
+ else:
+ padding_mat = [0, 0, 0, 0]
+ self._pad1 = nn.Pad2D(padding_mat, mode="replicate")
+
+ self._sn_conv1 = SNConv(
+ name=name + "_sn_conv1",
+ in_channels=channels,
+ out_channels=channels,
+ kernel_size=3,
+ padding=0,
+ norm_layer=norm_layer,
+ use_bias=use_bias,
+ act="ReLU",
+ act_attr=None)
+ if use_dropout:
+ self._dropout = nn.Dropout(0.5)
+ else:
+ self._dropout = None
+ self._pad2 = nn.Pad2D([1, 1, 1, 1], mode="replicate")
+ self._sn_conv2 = SNConv(
+ name=name + "_sn_conv2",
+ in_channels=channels,
+ out_channels=channels,
+ kernel_size=3,
+ norm_layer=norm_layer,
+ use_bias=use_bias,
+ act="ReLU",
+ act_attr=None)
+
+ def forward(self, x):
+ pad1 = self._pad1.forward(x)
+ sn_conv1 = self._sn_conv1.forward(pad1)
+ pad2 = self._pad2.forward(sn_conv1)
+ sn_conv2 = self._sn_conv2.forward(pad2)
+ return sn_conv2 + x
diff --git a/style_text_rec/arch/decoder.py b/style_text_rec/arch/decoder.py
new file mode 100644
index 00000000..36f07c59
--- /dev/null
+++ b/style_text_rec/arch/decoder.py
@@ -0,0 +1,251 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+
+from arch.base_module import SNConv, SNConvTranspose, ResBlock
+
+
+class Decoder(nn.Layer):
+ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
+ act, act_attr, conv_block_dropout, conv_block_num,
+ conv_block_dilation, out_conv_act, out_conv_act_attr):
+ super(Decoder, self).__init__()
+ conv_blocks = []
+ for i in range(conv_block_num):
+ conv_blocks.append(
+ ResBlock(
+ name="{}_conv_block_{}".format(name, i),
+ channels=encode_dim * 8,
+ norm_layer=norm_layer,
+ use_dropout=conv_block_dropout,
+ use_dilation=conv_block_dilation,
+ use_bias=use_bias))
+ self.conv_blocks = nn.Sequential(*conv_blocks)
+ self._up1 = SNConvTranspose(
+ name=name + "_up1",
+ in_channels=encode_dim * 8,
+ out_channels=encode_dim * 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._up2 = SNConvTranspose(
+ name=name + "_up2",
+ in_channels=encode_dim * 4,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._up3 = SNConvTranspose(
+ name=name + "_up3",
+ in_channels=encode_dim * 2,
+ out_channels=encode_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
+ self._out_conv = SNConv(
+ name=name + "_out_conv",
+ in_channels=encode_dim,
+ out_channels=out_channels,
+ kernel_size=3,
+ use_bias=use_bias,
+ norm_layer=None,
+ act=out_conv_act,
+ act_attr=out_conv_act_attr)
+
+ def forward(self, x):
+ if isinstance(x, (list, tuple)):
+ x = paddle.concat(x, axis=1)
+ output_dict = dict()
+ output_dict["conv_blocks"] = self.conv_blocks.forward(x)
+ output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
+ output_dict["up2"] = self._up2.forward(output_dict["up1"])
+ output_dict["up3"] = self._up3.forward(output_dict["up2"])
+ output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
+ output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
+ return output_dict
+
+
+class DecoderUnet(nn.Layer):
+ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
+ act, act_attr, conv_block_dropout, conv_block_num,
+ conv_block_dilation, out_conv_act, out_conv_act_attr):
+ super(DecoderUnet, self).__init__()
+ conv_blocks = []
+ for i in range(conv_block_num):
+ conv_blocks.append(
+ ResBlock(
+ name="{}_conv_block_{}".format(name, i),
+ channels=encode_dim * 8,
+ norm_layer=norm_layer,
+ use_dropout=conv_block_dropout,
+ use_dilation=conv_block_dilation,
+ use_bias=use_bias))
+ self._conv_blocks = nn.Sequential(*conv_blocks)
+ self._up1 = SNConvTranspose(
+ name=name + "_up1",
+ in_channels=encode_dim * 8,
+ out_channels=encode_dim * 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._up2 = SNConvTranspose(
+ name=name + "_up2",
+ in_channels=encode_dim * 8,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._up3 = SNConvTranspose(
+ name=name + "_up3",
+ in_channels=encode_dim * 4,
+ out_channels=encode_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
+ self._out_conv = SNConv(
+ name=name + "_out_conv",
+ in_channels=encode_dim,
+ out_channels=out_channels,
+ kernel_size=3,
+ use_bias=use_bias,
+ norm_layer=None,
+ act=out_conv_act,
+ act_attr=out_conv_act_attr)
+
+ def forward(self, x, y, feature2, feature1):
+ output_dict = dict()
+ output_dict["conv_blocks"] = self._conv_blocks(
+ paddle.concat(
+ (x, y), axis=1))
+ output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
+ output_dict["up2"] = self._up2.forward(
+ paddle.concat(
+ (output_dict["up1"], feature2), axis=1))
+ output_dict["up3"] = self._up3.forward(
+ paddle.concat(
+ (output_dict["up2"], feature1), axis=1))
+ output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
+ output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
+ return output_dict
+
+
+class SingleDecoder(nn.Layer):
+ def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
+ act, act_attr, conv_block_dropout, conv_block_num,
+ conv_block_dilation, out_conv_act, out_conv_act_attr):
+ super(SingleDecoder, self).__init__()
+ conv_blocks = []
+ for i in range(conv_block_num):
+ conv_blocks.append(
+ ResBlock(
+ name="{}_conv_block_{}".format(name, i),
+ channels=encode_dim * 4,
+ norm_layer=norm_layer,
+ use_dropout=conv_block_dropout,
+ use_dilation=conv_block_dilation,
+ use_bias=use_bias))
+ self._conv_blocks = nn.Sequential(*conv_blocks)
+ self._up1 = SNConvTranspose(
+ name=name + "_up1",
+ in_channels=encode_dim * 4,
+ out_channels=encode_dim * 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._up2 = SNConvTranspose(
+ name=name + "_up2",
+ in_channels=encode_dim * 8,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._up3 = SNConvTranspose(
+ name=name + "_up3",
+ in_channels=encode_dim * 4,
+ out_channels=encode_dim,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ output_padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
+ self._out_conv = SNConv(
+ name=name + "_out_conv",
+ in_channels=encode_dim,
+ out_channels=out_channels,
+ kernel_size=3,
+ use_bias=use_bias,
+ norm_layer=None,
+ act=out_conv_act,
+ act_attr=out_conv_act_attr)
+
+ def forward(self, x, feature2, feature1):
+ output_dict = dict()
+ output_dict["conv_blocks"] = self._conv_blocks.forward(x)
+ output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
+ output_dict["up2"] = self._up2.forward(
+ paddle.concat(
+ (output_dict["up1"], feature2), axis=1))
+ output_dict["up3"] = self._up3.forward(
+ paddle.concat(
+ (output_dict["up2"], feature1), axis=1))
+ output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
+ output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
+ return output_dict
diff --git a/style_text_rec/arch/encoder.py b/style_text_rec/arch/encoder.py
new file mode 100644
index 00000000..b884cda2
--- /dev/null
+++ b/style_text_rec/arch/encoder.py
@@ -0,0 +1,186 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+
+from arch.base_module import SNConv, SNConvTranspose, ResBlock
+
+
+class Encoder(nn.Layer):
+ def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
+ act, act_attr, conv_block_dropout, conv_block_num,
+ conv_block_dilation):
+ super(Encoder, self).__init__()
+ self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
+ self._in_conv = SNConv(
+ name=name + "_in_conv",
+ in_channels=in_channels,
+ out_channels=encode_dim,
+ kernel_size=7,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._down1 = SNConv(
+ name=name + "_down1",
+ in_channels=encode_dim,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._down2 = SNConv(
+ name=name + "_down2",
+ in_channels=encode_dim * 2,
+ out_channels=encode_dim * 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._down3 = SNConv(
+ name=name + "_down3",
+ in_channels=encode_dim * 4,
+ out_channels=encode_dim * 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ conv_blocks = []
+ for i in range(conv_block_num):
+ conv_blocks.append(
+ ResBlock(
+ name="{}_conv_block_{}".format(name, i),
+ channels=encode_dim * 4,
+ norm_layer=norm_layer,
+ use_dropout=conv_block_dropout,
+ use_dilation=conv_block_dilation,
+ use_bias=use_bias))
+ self._conv_blocks = nn.Sequential(*conv_blocks)
+
+ def forward(self, x):
+ out_dict = dict()
+ x = self._pad2d(x)
+ out_dict["in_conv"] = self._in_conv.forward(x)
+ out_dict["down1"] = self._down1.forward(out_dict["in_conv"])
+ out_dict["down2"] = self._down2.forward(out_dict["down1"])
+ out_dict["down3"] = self._down3.forward(out_dict["down2"])
+ out_dict["res_blocks"] = self._conv_blocks.forward(out_dict["down3"])
+ return out_dict
+
+
+class EncoderUnet(nn.Layer):
+ def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
+ act, act_attr):
+ super(EncoderUnet, self).__init__()
+ self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
+ self._in_conv = SNConv(
+ name=name + "_in_conv",
+ in_channels=in_channels,
+ out_channels=encode_dim,
+ kernel_size=7,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._down1 = SNConv(
+ name=name + "_down1",
+ in_channels=encode_dim,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._down2 = SNConv(
+ name=name + "_down2",
+ in_channels=encode_dim * 2,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._down3 = SNConv(
+ name=name + "_down3",
+ in_channels=encode_dim * 2,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._down4 = SNConv(
+ name=name + "_down4",
+ in_channels=encode_dim * 2,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._up1 = SNConvTranspose(
+ name=name + "_up1",
+ in_channels=encode_dim * 2,
+ out_channels=encode_dim * 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+ self._up2 = SNConvTranspose(
+ name=name + "_up2",
+ in_channels=encode_dim * 4,
+ out_channels=encode_dim * 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act=act,
+ act_attr=act_attr)
+
+ def forward(self, x):
+ output_dict = dict()
+ x = self._pad2d(x)
+ output_dict['in_conv'] = self._in_conv.forward(x)
+ output_dict['down1'] = self._down1.forward(output_dict['in_conv'])
+ output_dict['down2'] = self._down2.forward(output_dict['down1'])
+ output_dict['down3'] = self._down3.forward(output_dict['down2'])
+ output_dict['down4'] = self._down4.forward(output_dict['down3'])
+ output_dict['up1'] = self._up1.forward(output_dict['down4'])
+ output_dict['up2'] = self._up2.forward(
+ paddle.concat(
+ (output_dict['down3'], output_dict['up1']), axis=1))
+ output_dict['concat'] = paddle.concat(
+ (output_dict['down2'], output_dict['up2']), axis=1)
+ return output_dict
diff --git a/style_text_rec/arch/spectral_norm.py b/style_text_rec/arch/spectral_norm.py
new file mode 100644
index 00000000..21d0afc8
--- /dev/null
+++ b/style_text_rec/arch/spectral_norm.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
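+# Fill `x` in place with samples from N(mean, std**2) via set_value.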
+def normal_(x, mean=0., std=1.):
+ temp_value = paddle.normal(mean, std, shape=x.shape)
+ x.set_value(temp_value)
+ return x
+
+
+class SpectralNorm(object):
+ def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
+ self.name = name
+ self.dim = dim
+ if n_power_iterations <= 0:
+ raise ValueError('Expected n_power_iterations to be positive, but '
+ 'got n_power_iterations={}'.format(
+ n_power_iterations))
+ self.n_power_iterations = n_power_iterations
+ self.eps = eps
+
+ def reshape_weight_to_matrix(self, weight):
+ weight_mat = weight
+ if self.dim != 0:
+ # transpose dim to front
+ weight_mat = weight_mat.transpose([
+ self.dim,
+ * [d for d in range(weight_mat.dim()) if d != self.dim]
+ ])
+
+ height = weight_mat.shape[0]
+
+ return weight_mat.reshape([height, -1])
+
+ def compute_weight(self, module, do_power_iteration):
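+        # Power iteration estimates the largest singular value (sigma) of the
+        # reshaped weight matrix; the returned weight is weight_orig / sigma.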
+ weight = getattr(module, self.name + '_orig')
+ u = getattr(module, self.name + '_u')
+ v = getattr(module, self.name + '_v')
+ weight_mat = self.reshape_weight_to_matrix(weight)
+
+ if do_power_iteration:
+ with paddle.no_grad():
+ for _ in range(self.n_power_iterations):
+ v.set_value(
+ F.normalize(
+ paddle.matmul(
+ weight_mat,
+ u,
+ transpose_x=True,
+ transpose_y=False),
+ axis=0,
+ epsilon=self.eps, ))
+
+ u.set_value(
+ F.normalize(
+ paddle.matmul(weight_mat, v),
+ axis=0,
+ epsilon=self.eps, ))
+ if self.n_power_iterations > 0:
+ u = u.clone()
+ v = v.clone()
+
+ sigma = paddle.dot(u, paddle.mv(weight_mat, v))
+ weight = weight / sigma
+ return weight
+
+ def remove(self, module):
+ with paddle.no_grad():
+ weight = self.compute_weight(module, do_power_iteration=False)
+ delattr(module, self.name)
+ delattr(module, self.name + '_u')
+ delattr(module, self.name + '_v')
+ delattr(module, self.name + '_orig')
+
+ module.add_parameter(self.name, weight.detach())
+
+ def __call__(self, module, inputs):
+ setattr(
+ module,
+ self.name,
+ self.compute_weight(
+ module, do_power_iteration=module.training))
+
+ @staticmethod
+ def apply(module, name, n_power_iterations, dim, eps):
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, SpectralNorm) and hook.name == name:
+ raise RuntimeError(
+ "Cannot register two spectral_norm hooks on "
+ "the same parameter {}".format(name))
+
+ fn = SpectralNorm(name, n_power_iterations, dim, eps)
+ weight = module._parameters[name]
+
+ with paddle.no_grad():
+ weight_mat = fn.reshape_weight_to_matrix(weight)
+ h, w = weight_mat.shape
+
+ # randomly initialize u and v
+ u = module.create_parameter([h])
+ u = normal_(u, 0., 1.)
+ v = module.create_parameter([w])
+ v = normal_(v, 0., 1.)
+ u = F.normalize(u, axis=0, epsilon=fn.eps)
+ v = F.normalize(v, axis=0, epsilon=fn.eps)
+
+        # delete fn.name from parameters, otherwise you cannot set the attribute
+ del module._parameters[fn.name]
+ module.add_parameter(fn.name + "_orig", weight)
+ # still need to assign weight back as fn.name because all sorts of
+ # things may assume that it exists, e.g., when initializing weights.
+        # However, we can't directly assign it, since it could be a Parameter
+        # and would get added as a parameter. Instead, we register weight * 1.0
+        # as a plain attribute.
+ setattr(module, fn.name, weight * 1.0)
+ module.register_buffer(fn.name + "_u", u)
+ module.register_buffer(fn.name + "_v", v)
+
+ module.register_forward_pre_hook(fn)
+ return fn
+
+
+def spectral_norm(module,
+ name='weight',
+ n_power_iterations=1,
+ eps=1e-12,
+ dim=None):
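+    """Apply spectral normalization to the `name` parameter of `module` by
+    registering a SpectralNorm pre-forward hook (analogous to
+    torch.nn.utils.spectral_norm)."""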
+
+ if dim is None:
+ if isinstance(module, (nn.Conv1DTranspose, nn.Conv2DTranspose,
+ nn.Conv3DTranspose, nn.Linear)):
+ dim = 1
+ else:
+ dim = 0
+ SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
+ return module
diff --git a/style_text_rec/arch/style_text_rec.py b/style_text_rec/arch/style_text_rec.py
new file mode 100644
index 00000000..599927ce
--- /dev/null
+++ b/style_text_rec/arch/style_text_rec.py
@@ -0,0 +1,285 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+import paddle.nn as nn
+
+from arch.base_module import MiddleNet, ResBlock
+from arch.encoder import Encoder
+from arch.decoder import Decoder, DecoderUnet, SingleDecoder
+from utils.load_params import load_dygraph_pretrain
+from utils.logging import get_logger
+
+
+class StyleTextRec(nn.Layer):
+ def __init__(self, config):
+ super(StyleTextRec, self).__init__()
+ self.logger = get_logger()
+ self.text_generator = TextGenerator(config["Predictor"][
+ "text_generator"])
+ self.bg_generator = BgGeneratorWithMask(config["Predictor"][
+ "bg_generator"])
+ self.fusion_generator = FusionGeneratorSimple(config["Predictor"][
+ "fusion_generator"])
+ bg_generator_pretrain = config["Predictor"]["bg_generator"]["pretrain"]
+ text_generator_pretrain = config["Predictor"]["text_generator"][
+ "pretrain"]
+ fusion_generator_pretrain = config["Predictor"]["fusion_generator"][
+ "pretrain"]
+ load_dygraph_pretrain(
+ self.bg_generator,
+ self.logger,
+ path=bg_generator_pretrain,
+ load_static_weights=False)
+ load_dygraph_pretrain(
+ self.text_generator,
+ self.logger,
+ path=text_generator_pretrain,
+ load_static_weights=False)
+ load_dygraph_pretrain(
+ self.fusion_generator,
+ self.logger,
+ path=fusion_generator_pretrain,
+ load_static_weights=False)
+
+ def forward(self, style_input, text_input):
+ text_gen_output = self.text_generator.forward(style_input, text_input)
+ fake_text = text_gen_output["fake_text"]
+ fake_sk = text_gen_output["fake_sk"]
+ bg_gen_output = self.bg_generator.forward(style_input)
+ bg_encode_feature = bg_gen_output["bg_encode_feature"]
+ bg_decode_feature1 = bg_gen_output["bg_decode_feature1"]
+ bg_decode_feature2 = bg_gen_output["bg_decode_feature2"]
+ fake_bg = bg_gen_output["fake_bg"]
+
+ fusion_gen_output = self.fusion_generator.forward(fake_text, fake_bg)
+ fake_fusion = fusion_gen_output["fake_fusion"]
+ return {
+ "fake_fusion": fake_fusion,
+ "fake_text": fake_text,
+ "fake_sk": fake_sk,
+ "fake_bg": fake_bg,
+ }
+
+
+class TextGenerator(nn.Layer):
+ def __init__(self, config):
+ super(TextGenerator, self).__init__()
+ name = config["module_name"]
+ encode_dim = config["encode_dim"]
+ norm_layer = config["norm_layer"]
+ conv_block_dropout = config["conv_block_dropout"]
+ conv_block_num = config["conv_block_num"]
+ conv_block_dilation = config["conv_block_dilation"]
+ if norm_layer == "InstanceNorm2D":
+ use_bias = True
+ else:
+ use_bias = False
+ self.encoder_text = Encoder(
+ name=name + "_encoder_text",
+ in_channels=3,
+ encode_dim=encode_dim,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act="ReLU",
+ act_attr=None,
+ conv_block_dropout=conv_block_dropout,
+ conv_block_num=conv_block_num,
+ conv_block_dilation=conv_block_dilation)
+ self.encoder_style = Encoder(
+ name=name + "_encoder_style",
+ in_channels=3,
+ encode_dim=encode_dim,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act="ReLU",
+ act_attr=None,
+ conv_block_dropout=conv_block_dropout,
+ conv_block_num=conv_block_num,
+ conv_block_dilation=conv_block_dilation)
+ self.decoder_text = Decoder(
+ name=name + "_decoder_text",
+ encode_dim=encode_dim,
+ out_channels=int(encode_dim / 2),
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act="ReLU",
+ act_attr=None,
+ conv_block_dropout=conv_block_dropout,
+ conv_block_num=conv_block_num,
+ conv_block_dilation=conv_block_dilation,
+ out_conv_act="Tanh",
+ out_conv_act_attr=None)
+ self.decoder_sk = Decoder(
+ name=name + "_decoder_sk",
+ encode_dim=encode_dim,
+ out_channels=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act="ReLU",
+ act_attr=None,
+ conv_block_dropout=conv_block_dropout,
+ conv_block_num=conv_block_num,
+ conv_block_dilation=conv_block_dilation,
+ out_conv_act="Sigmoid",
+ out_conv_act_attr=None)
+
+ self.middle = MiddleNet(
+ name=name + "_middle_net",
+ in_channels=int(encode_dim / 2) + 1,
+ mid_channels=encode_dim,
+ out_channels=3,
+ use_bias=use_bias)
+
+ def forward(self, style_input, text_input):
+ style_feature = self.encoder_style.forward(style_input)["res_blocks"]
+ text_feature = self.encoder_text.forward(text_input)["res_blocks"]
+ fake_c_temp = self.decoder_text.forward([text_feature,
+ style_feature])["out_conv"]
+ fake_sk = self.decoder_sk.forward([text_feature,
+ style_feature])["out_conv"]
+ fake_text = self.middle(paddle.concat((fake_c_temp, fake_sk), axis=1))
+ return {"fake_sk": fake_sk, "fake_text": fake_text}
+
+
+class BgGeneratorWithMask(nn.Layer):
+ def __init__(self, config):
+ super(BgGeneratorWithMask, self).__init__()
+ name = config["module_name"]
+ encode_dim = config["encode_dim"]
+ norm_layer = config["norm_layer"]
+ conv_block_dropout = config["conv_block_dropout"]
+ conv_block_num = config["conv_block_num"]
+ conv_block_dilation = config["conv_block_dilation"]
+ self.output_factor = config.get("output_factor", 1.0)
+
+ if norm_layer == "InstanceNorm2D":
+ use_bias = True
+ else:
+ use_bias = False
+
+ self.encoder_bg = Encoder(
+ name=name + "_encoder_bg",
+ in_channels=3,
+ encode_dim=encode_dim,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act="ReLU",
+ act_attr=None,
+ conv_block_dropout=conv_block_dropout,
+ conv_block_num=conv_block_num,
+ conv_block_dilation=conv_block_dilation)
+
+ self.decoder_bg = SingleDecoder(
+ name=name + "_decoder_bg",
+ encode_dim=encode_dim,
+ out_channels=3,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act="ReLU",
+ act_attr=None,
+ conv_block_dropout=conv_block_dropout,
+ conv_block_num=conv_block_num,
+ conv_block_dilation=conv_block_dilation,
+ out_conv_act="Tanh",
+ out_conv_act_attr=None)
+
+ self.decoder_mask = Decoder(
+ name=name + "_decoder_mask",
+ encode_dim=encode_dim // 2,
+ out_channels=1,
+ use_bias=use_bias,
+ norm_layer=norm_layer,
+ act="ReLU",
+ act_attr=None,
+ conv_block_dropout=conv_block_dropout,
+ conv_block_num=conv_block_num,
+ conv_block_dilation=conv_block_dilation,
+ out_conv_act="Sigmoid",
+ out_conv_act_attr=None)
+
+ self.middle = MiddleNet(
+ name=name + "_middle_net",
+ in_channels=3 + 1,
+ mid_channels=encode_dim,
+ out_channels=3,
+ use_bias=use_bias)
+
+ def forward(self, style_input):
+ encode_bg_output = self.encoder_bg(style_input)
+ decode_bg_output = self.decoder_bg(encode_bg_output["res_blocks"],
+ encode_bg_output["down2"],
+ encode_bg_output["down1"])
+
+ fake_c_temp = decode_bg_output["out_conv"]
+ fake_bg_mask = self.decoder_mask.forward(encode_bg_output[
+ "res_blocks"])["out_conv"]
+ fake_bg = self.middle(
+ paddle.concat(
+ (fake_c_temp, fake_bg_mask), axis=1))
+ return {
+ "bg_encode_feature": encode_bg_output["res_blocks"],
+ "bg_decode_feature1": decode_bg_output["up1"],
+ "bg_decode_feature2": decode_bg_output["up2"],
+ "fake_bg": fake_bg,
+ "fake_bg_mask": fake_bg_mask,
+ }
+
+
+class FusionGeneratorSimple(nn.Layer):
+ def __init__(self, config):
+ super(FusionGeneratorSimple, self).__init__()
+ name = config["module_name"]
+ encode_dim = config["encode_dim"]
+ norm_layer = config["norm_layer"]
+ conv_block_dropout = config["conv_block_dropout"]
+ conv_block_dilation = config["conv_block_dilation"]
+ if norm_layer == "InstanceNorm2D":
+ use_bias = True
+ else:
+ use_bias = False
+
+ self._conv = nn.Conv2D(
+ in_channels=6,
+ out_channels=encode_dim,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=1,
+ weight_attr=paddle.ParamAttr(name=name + "_conv_weights"),
+ bias_attr=False)
+
+ self._res_block = ResBlock(
+ name="{}_conv_block".format(name),
+ channels=encode_dim,
+ norm_layer=norm_layer,
+ use_dropout=conv_block_dropout,
+ use_dilation=conv_block_dilation,
+ use_bias=use_bias)
+
+ self._reduce_conv = nn.Conv2D(
+ in_channels=encode_dim,
+ out_channels=3,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ groups=1,
+ weight_attr=paddle.ParamAttr(name=name + "_reduce_conv_weights"),
+ bias_attr=False)
+
+ def forward(self, fake_text, fake_bg):
+ fake_concat = paddle.concat((fake_text, fake_bg), axis=1)
+ fake_concat_tmp = self._conv(fake_concat)
+ output_res = self._res_block(fake_concat_tmp)
+ fake_fusion = self._reduce_conv(output_res)
+ return {"fake_fusion": fake_fusion}
diff --git a/style_text_rec/configs/config.yml b/style_text_rec/configs/config.yml
new file mode 100644
index 00000000..3b10b3d2
--- /dev/null
+++ b/style_text_rec/configs/config.yml
@@ -0,0 +1,54 @@
+Global:
+ output_num: 10
+ output_dir: output_data
+ use_gpu: false
+ image_height: 32
+ image_width: 320
+TextDrawer:
+ fonts:
+ en: fonts/en_standard.ttf
+ ch: fonts/ch_standard.ttf
+ ko: fonts/ko_standard.ttf
+Predictor:
+ method: StyleTextRecPredictor
+ algorithm: StyleTextRec
+ scale: 0.00392156862745098
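+  # scale = 1/255: pixel values are mapped to [0, 1] before the mean/std
+  # normalization applied by StyleTextRecPredictor.preprocess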
+ mean:
+ - 0.5
+ - 0.5
+ - 0.5
+ std:
+ - 0.5
+ - 0.5
+ - 0.5
+ expand_result: false
+ bg_generator:
+ pretrain: style_text_models/bg_generator
+ module_name: bg_generator
+ generator_type: BgGeneratorWithMask
+ encode_dim: 64
+ norm_layer: null
+ conv_block_num: 4
+ conv_block_dropout: false
+ conv_block_dilation: true
+ output_factor: 1.05
+ text_generator:
+ pretrain: style_text_models/text_generator
+ module_name: text_generator
+ generator_type: TextGenerator
+ encode_dim: 64
+ norm_layer: InstanceNorm2D
+ conv_block_num: 4
+ conv_block_dropout: false
+ conv_block_dilation: true
+ fusion_generator:
+ pretrain: style_text_models/fusion_generator
+ module_name: fusion_generator
+ generator_type: FusionGeneratorSimple
+ encode_dim: 64
+ norm_layer: null
+ conv_block_num: 4
+ conv_block_dropout: false
+ conv_block_dilation: true
+Writer:
+ method: SimpleWriter
diff --git a/style_text_rec/configs/dataset_config.yml b/style_text_rec/configs/dataset_config.yml
new file mode 100644
index 00000000..e047489e
--- /dev/null
+++ b/style_text_rec/configs/dataset_config.yml
@@ -0,0 +1,64 @@
+Global:
+ output_num: 10
+ output_dir: output_data
+ use_gpu: false
+ image_height: 32
+ image_width: 320
+ standard_font: fonts/en_standard.ttf
+TextDrawer:
+ fonts:
+ en: fonts/en_standard.ttf
+ ch: fonts/ch_standard.ttf
+ ko: fonts/ko_standard.ttf
+StyleSampler:
+ method: DatasetSampler
+ image_home: examples
+ label_file: examples/image_list.txt
+ with_label: true
+CorpusGenerator:
+ method: FileCorpus
+ language: ch
+ corpus_file: examples/corpus/example.txt
+Predictor:
+ method: StyleTextRecPredictor
+ algorithm: StyleTextRec
+ scale: 0.00392156862745098
+ mean:
+ - 0.5
+ - 0.5
+ - 0.5
+ std:
+ - 0.5
+ - 0.5
+ - 0.5
+ expand_result: false
+ bg_generator:
+ pretrain: models/style_text_rec/bg_generator
+ module_name: bg_generator
+ generator_type: BgGeneratorWithMask
+ encode_dim: 64
+ norm_layer: null
+ conv_block_num: 4
+ conv_block_dropout: false
+ conv_block_dilation: true
+ output_factor: 1.05
+ text_generator:
+ pretrain: models/style_text_rec/text_generator
+ module_name: text_generator
+ generator_type: TextGenerator
+ encode_dim: 64
+ norm_layer: InstanceNorm2D
+ conv_block_num: 4
+ conv_block_dropout: false
+ conv_block_dilation: true
+ fusion_generator:
+ pretrain: models/style_text_rec/fusion_generator
+ module_name: fusion_generator
+ generator_type: FusionGeneratorSimple
+ encode_dim: 64
+ norm_layer: null
+ conv_block_num: 4
+ conv_block_dropout: false
+ conv_block_dilation: true
+Writer:
+ method: SimpleWriter
diff --git a/style_text_rec/engine/__init__.py b/style_text_rec/engine/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/style_text_rec/engine/corpus_generators.py b/style_text_rec/engine/corpus_generators.py
new file mode 100644
index 00000000..186d15f3
--- /dev/null
+++ b/style_text_rec/engine/corpus_generators.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import random
+
+from utils.logging import get_logger
+
+
+class FileCorpus(object):
+ def __init__(self, config):
+ self.logger = get_logger()
+ self.logger.info("using FileCorpus")
+
+ self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+
+ corpus_file = config["CorpusGenerator"]["corpus_file"]
+ self.language = config["CorpusGenerator"]["language"]
+ with open(corpus_file, 'r') as f:
+ corpus_raw = f.read()
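+        # splitting on "\n" leaves a trailing empty string; drop it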
+ self.corpus_list = corpus_raw.split("\n")[:-1]
+ assert len(self.corpus_list) > 0
+ random.shuffle(self.corpus_list)
+ self.index = 0
+
+ def generate(self, corpus_length=0):
+ if self.index >= len(self.corpus_list):
+ self.index = 0
+ random.shuffle(self.corpus_list)
+ corpus = self.corpus_list[self.index]
+ if corpus_length != 0:
+ corpus = corpus[0:corpus_length]
+ if corpus_length > len(corpus):
+ self.logger.warning("generated corpus is shorter than expected.")
+ self.index += 1
+ return self.language, corpus
+
+
+class EnNumCorpus(object):
+ def __init__(self, config):
+ self.logger = get_logger()
+        self.logger.info("using EnNumCorpus")
+ self.num_list = "0123456789"
+ self.en_char_list = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+ self.height = config["Global"]["image_height"]
+ self.max_width = config["Global"]["image_width"]
+
+ def generate(self, corpus_length=0):
+ corpus = ""
+ if corpus_length == 0:
+ corpus_length = random.randint(5, 15)
+ for i in range(corpus_length):
+ if random.random() < 0.2:
+ corpus += "{}".format(random.choice(self.en_char_list))
+ else:
+ corpus += "{}".format(random.choice(self.num_list))
+ return "en", corpus
diff --git a/style_text_rec/engine/predictors.py b/style_text_rec/engine/predictors.py
new file mode 100644
index 00000000..d9f4afe4
--- /dev/null
+++ b/style_text_rec/engine/predictors.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import cv2
+import math
+import paddle
+
+from arch import style_text_rec
+from utils.sys_funcs import check_gpu
+from utils.logging import get_logger
+
+
+class StyleTextRecPredictor(object):
+ def __init__(self, config):
+ algorithm = config['Predictor']['algorithm']
+ assert algorithm in ["StyleTextRec"
+ ], "Generator {} not supported.".format(algorithm)
+ use_gpu = config["Global"]['use_gpu']
+ check_gpu(use_gpu)
+ self.logger = get_logger()
+ self.generator = getattr(style_text_rec, algorithm)(config)
+ self.height = config["Global"]["image_height"]
+ self.width = config["Global"]["image_width"]
+ self.scale = config["Predictor"]["scale"]
+ self.mean = config["Predictor"]["mean"]
+ self.std = config["Predictor"]["std"]
+ self.expand_result = config["Predictor"]["expand_result"]
+
+ def predict(self, style_input, text_input):
+ style_input = self.rep_style_input(style_input, text_input)
+ tensor_style_input = self.preprocess(style_input)
+ tensor_text_input = self.preprocess(text_input)
+ style_text_result = self.generator.forward(tensor_style_input,
+ tensor_text_input)
+ fake_fusion = self.postprocess(style_text_result["fake_fusion"])
+ fake_text = self.postprocess(style_text_result["fake_text"])
+ fake_sk = self.postprocess(style_text_result["fake_sk"])
+ fake_bg = self.postprocess(style_text_result["fake_bg"])
+ bbox = self.get_text_boundary(fake_text)
+ if bbox:
+ left, right, top, bottom = bbox
+ fake_fusion = fake_fusion[top:bottom, left:right, :]
+ fake_text = fake_text[top:bottom, left:right, :]
+ fake_sk = fake_sk[top:bottom, left:right, :]
+ fake_bg = fake_bg[top:bottom, left:right, :]
+
+ return {
+ "fake_fusion": fake_fusion,
+ "fake_text": fake_text,
+ "fake_sk": fake_sk,
+ "fake_bg": fake_bg,
+ }
+
+ def preprocess(self, img):
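+        # Normalize pixels to [-1, 1], resize to the target height keeping
+        # aspect ratio, zero-pad on the right to the fixed width, and convert
+        # HWC uint8 to an NCHW float tensor.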
+ img = (img.astype('float32') * self.scale - self.mean) / self.std
+ img_height, img_width, channel = img.shape
+        assert channel == 3, "Please use an RGB image."
+ ratio = img_width / float(img_height)
+ if math.ceil(self.height * ratio) > self.width:
+ resized_w = self.width
+ else:
+ resized_w = int(math.ceil(self.height * ratio))
+ img = cv2.resize(img, (resized_w, self.height))
+
+ new_img = np.zeros([self.height, self.width, 3]).astype('float32')
+ new_img[:, 0:resized_w, :] = img
+ img = new_img.transpose((2, 0, 1))
+ img = img[np.newaxis, :, :, :]
+ return paddle.to_tensor(img)
+
+ def postprocess(self, tensor):
+ img = tensor.numpy()[0]
+ img = img.transpose((1, 2, 0))
+ img = (img * self.std + self.mean) / self.scale
+ img = np.maximum(img, 0.0)
+ img = np.minimum(img, 255.0)
+ img = img.astype('uint8')
+ return img
+
+ def rep_style_input(self, style_input, text_input):
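+        # Tile the style image horizontally until it is at least as wide
+        # (relative to its height) as the rendered text, then crop it to the
+        # predictor's width/height ratio.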
+ rep_num = int(1.2 * (text_input.shape[1] / text_input.shape[0]) /
+ (style_input.shape[1] / style_input.shape[0])) + 1
+ style_input = np.tile(style_input, reps=[1, rep_num, 1])
+ max_width = int(self.width / self.height * style_input.shape[0])
+ style_input = style_input[:, :max_width, :]
+ return style_input
+
+ def get_text_boundary(self, text_img):
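+        # Locate the tight bounding box of the rendered glyphs from Canny
+        # edge statistics, padded by a 3-pixel border; return None when no
+        # edges are found (blank image).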
+ img_height = text_img.shape[0]
+ img_width = text_img.shape[1]
+ bounder = 3
+ text_canny_img = cv2.Canny(text_img, 10, 20)
+ edge_num_h = text_canny_img.sum(axis=0)
+ no_zero_list_h = np.where(edge_num_h > 0)[0]
+ edge_num_w = text_canny_img.sum(axis=1)
+ no_zero_list_w = np.where(edge_num_w > 0)[0]
+ if len(no_zero_list_h) == 0 or len(no_zero_list_w) == 0:
+ return None
+ left = max(no_zero_list_h[0] - bounder, 0)
+ right = min(no_zero_list_h[-1] + bounder, img_width)
+ top = max(no_zero_list_w[0] - bounder, 0)
+ bottom = min(no_zero_list_w[-1] + bounder, img_height)
+ return [left, right, top, bottom]
diff --git a/style_text_rec/engine/style_samplers.py b/style_text_rec/engine/style_samplers.py
new file mode 100644
index 00000000..e171d58d
--- /dev/null
+++ b/style_text_rec/engine/style_samplers.py
@@ -0,0 +1,62 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import random
+import cv2
+
+
+class DatasetSampler(object):
+ def __init__(self, config):
+ self.image_home = config["StyleSampler"]["image_home"]
+ label_file = config["StyleSampler"]["label_file"]
+ self.dataset_with_label = config["StyleSampler"]["with_label"]
+ self.height = config["Global"]["image_height"]
+ self.index = 0
+ with open(label_file, "r") as f:
+ label_raw = f.read()
+ self.path_label_list = label_raw.split("\n")[:-1]
+ assert len(self.path_label_list) > 0
+ random.shuffle(self.path_label_list)
+
+ def sample(self):
+ if self.index >= len(self.path_label_list):
+ random.shuffle(self.path_label_list)
+ self.index = 0
+ if self.dataset_with_label:
+ path_label = self.path_label_list[self.index]
+ rel_image_path, label = path_label.split('\t')
+ else:
+ rel_image_path = self.path_label_list[self.index]
+ label = None
+ img_path = "{}/{}".format(self.image_home, rel_image_path)
+ image = cv2.imread(img_path)
+ origin_height = image.shape[0]
+ ratio = self.height / origin_height
+ width = int(image.shape[1] * ratio)
+ height = int(image.shape[0] * ratio)
+ image = cv2.resize(image, (width, height))
+
+ self.index += 1
+ if label:
+ return {"image": image, "label": label}
+ else:
+ return {"image": image}
+
+
+def duplicate_image(image, width):
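+    # Tile the image horizontally, then crop the result to `width` pixels.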
+ image_width = image.shape[1]
+ dup_num = width // image_width + 1
+ image = np.tile(image, reps=[1, dup_num, 1])
+ cropped_image = image[:, :width, :]
+ return cropped_image
diff --git a/style_text_rec/engine/synthesisers.py b/style_text_rec/engine/synthesisers.py
new file mode 100644
index 00000000..177e3e04
--- /dev/null
+++ b/style_text_rec/engine/synthesisers.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+
+from utils.config import ArgsParser, load_config, override_config
+from utils.logging import get_logger
+from engine import style_samplers, corpus_generators, text_drawers, predictors, writers
+
+
+class ImageSynthesiser(object):
+ def __init__(self):
+ self.FLAGS = ArgsParser().parse_args()
+ self.config = load_config(self.FLAGS.config)
+ self.config = override_config(self.config, options=self.FLAGS.override)
+ self.output_dir = self.config["Global"]["output_dir"]
+ if not os.path.exists(self.output_dir):
+ os.mkdir(self.output_dir)
+ self.logger = get_logger(
+ log_file='{}/predict.log'.format(self.output_dir))
+
+ self.text_drawer = text_drawers.StdTextDrawer(self.config)
+
+ predictor_method = self.config["Predictor"]["method"]
+ assert predictor_method is not None
+ self.predictor = getattr(predictors, predictor_method)(self.config)
+
+ def synth_image(self, corpus, style_input, language="en"):
+ corpus, text_input = self.text_drawer.draw_text(corpus, language)
+ synth_result = self.predictor.predict(style_input, text_input)
+ return synth_result
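+
+# A minimal usage sketch (assumed invocation: a valid `-c` config must be on
+# the command line, since __init__ parses sys.argv; the style image is a BGR
+# array, e.g. loaded with cv2.imread):
+#
+#   synthesiser = ImageSynthesiser()
+#   result = synthesiser.synth_image("PaddleOCR", cv2.imread("style.jpg"), "en")
+#   cv2.imwrite("fake_fusion.jpg", result["fake_fusion"])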
+
+
+class DatasetSynthesiser(ImageSynthesiser):
+ def __init__(self):
+ super(DatasetSynthesiser, self).__init__()
+ self.tag = self.FLAGS.tag
+ self.output_num = self.config["Global"]["output_num"]
+ corpus_generator_method = self.config["CorpusGenerator"]["method"]
+ self.corpus_generator = getattr(corpus_generators,
+ corpus_generator_method)(self.config)
+
+ style_sampler_method = self.config["StyleSampler"]["method"]
+ assert style_sampler_method is not None
+ self.style_sampler = style_samplers.DatasetSampler(self.config)
+ self.writer = writers.SimpleWriter(self.config, self.tag)
+
+ def synth_dataset(self):
+ for i in range(self.output_num):
+ style_data = self.style_sampler.sample()
+ style_input = style_data["image"]
+            corpus_language, text_input_label = \
+                self.corpus_generator.generate()
+ text_input_label, text_input = self.text_drawer.draw_text(
+ text_input_label, corpus_language)
+
+ synth_result = self.predictor.predict(style_input, text_input)
+ fake_fusion = synth_result["fake_fusion"]
+ self.writer.save_image(fake_fusion, text_input_label)
+ self.writer.save_label()
+ self.writer.merge_label()
diff --git a/style_text_rec/engine/text_drawers.py b/style_text_rec/engine/text_drawers.py
new file mode 100644
index 00000000..8aaac06e
--- /dev/null
+++ b/style_text_rec/engine/text_drawers.py
@@ -0,0 +1,57 @@
+from PIL import Image, ImageDraw, ImageFont
+import numpy as np
+from utils.logging import get_logger
+
+
+class StdTextDrawer(object):
+ def __init__(self, config):
+ self.logger = get_logger()
+ self.max_width = config["Global"]["image_width"]
+ self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
+ self.height = config["Global"]["image_height"]
+ self.font_dict = {}
+ self.load_fonts(config["TextDrawer"]["fonts"])
+ self.support_languages = list(self.font_dict)
+
+ def load_fonts(self, fonts_config):
+ for language in fonts_config:
+ font_path = fonts_config[language]
+ font_height = self.get_valid_height(font_path)
+ font = ImageFont.truetype(font_path, font_height)
+ self.font_dict[language] = font
+
+ def get_valid_height(self, font_path):
+ font = ImageFont.truetype(font_path, self.height - 4)
+ _, font_height = font.getsize(self.char_list)
+ if font_height <= self.height - 4:
+ return self.height - 4
+ else:
+ return int((self.height - 4)**2 / font_height)
+
+ def draw_text(self, corpus, language="en", crop=True):
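+        """Render `corpus` in black on a grey canvas using the font
+        registered for `language`; characters that would overflow the
+        canvas are dropped. Returns the (possibly truncated) corpus and
+        the rendered image as a uint8 array."""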
+ if language not in self.support_languages:
+            self.logger.warning(
+                "language {} not supported, using en instead".format(language))
+ language = "en"
+ if crop:
+ width = min(self.max_width, len(corpus) * self.height) + 4
+ else:
+ width = len(corpus) * self.height + 4
+ bg = Image.new("RGB", (width, self.height), color=(127, 127, 127))
+ draw = ImageDraw.Draw(bg)
+
+ char_x = 2
+ font = self.font_dict[language]
+ for i, char_i in enumerate(corpus):
+ char_size = font.getsize(char_i)[0]
+ draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
+ char_x += char_size
+            if char_x >= width:
+                corpus = corpus[0:i + 1]
+                self.logger.warning(
+                    "corpus exceeds width limit, truncated to: {}".format(
+                        corpus))
+                break
+
+ text_input = np.array(bg).astype(np.uint8)
+ text_input = text_input[:, 0:char_x, :]
+ return corpus, text_input
diff --git a/style_text_rec/engine/writers.py b/style_text_rec/engine/writers.py
new file mode 100644
index 00000000..0df75e72
--- /dev/null
+++ b/style_text_rec/engine/writers.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import cv2
+import glob
+
+from utils.logging import get_logger
+
+
+class SimpleWriter(object):
+ def __init__(self, config, tag):
+ self.logger = get_logger()
+ self.output_dir = config["Global"]["output_dir"]
+ self.counter = 0
+ self.label_dict = {}
+ self.tag = tag
+ self.label_file_index = 0
+
+ def save_image(self, image, text_input_label):
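+        """Write the image to <output_dir>/images/<tag>/<counter>.png and
+        record its label under the relative image name."""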
+ image_home = os.path.join(self.output_dir, "images", self.tag)
+ if not os.path.exists(image_home):
+ os.makedirs(image_home)
+
+ image_path = os.path.join(image_home, "{}.png".format(self.counter))
+ # todo support continue synth
+ cv2.imwrite(image_path, image)
+ self.logger.info("generate image: {}".format(image_path))
+
+ image_name = os.path.join(self.tag, "{}.png".format(self.counter))
+ self.label_dict[image_name] = text_input_label
+
+        self.counter += 1
+        # flush accumulated labels every 100 images so progress is not lost
+        if self.counter % 100 == 0:
+            self.save_label()
+
+ def save_label(self):
+ label_raw = ""
+ label_home = os.path.join(self.output_dir, "label")
+        # makedirs with exist_ok avoids a race when several tagged workers
+        # write into the same output_dir
+        os.makedirs(label_home, exist_ok=True)
+ for image_path in self.label_dict:
+ label = self.label_dict[image_path]
+ label_raw += "{}\t{}\n".format(image_path, label)
+ label_file_path = os.path.join(label_home,
+ "{}_label.txt".format(self.tag))
+ with open(label_file_path, "w") as f:
+ f.write(label_raw)
+ self.label_file_index += 1
+
+ def merge_label(self):
+ label_raw = ""
+ label_file_regex = os.path.join(self.output_dir, "label",
+ "*_label.txt")
+ label_file_list = glob.glob(label_file_regex)
+ for label_file_i in label_file_list:
+ with open(label_file_i, "r") as f:
+ label_raw += f.read()
+ label_file_path = os.path.join(self.output_dir, "label.txt")
+ with open(label_file_path, "w") as f:
+ f.write(label_raw)
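+
+# Resulting layout under Global.output_dir, one images/<tag> directory and
+# one <tag>_label.txt per worker tag, merged into a single label.txt:
+#   images/<tag>/<n>.png
+#   label/<tag>_label.txt
+#   label.txt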
diff --git a/style_text_rec/examples/corpus/example.txt b/style_text_rec/examples/corpus/example.txt
new file mode 100644
index 00000000..78451cc3
--- /dev/null
+++ b/style_text_rec/examples/corpus/example.txt
@@ -0,0 +1,2 @@
+PaddleOCR
+飞桨文字识别
diff --git a/style_text_rec/examples/image_list.txt b/style_text_rec/examples/image_list.txt
new file mode 100644
index 00000000..b07be035
--- /dev/null
+++ b/style_text_rec/examples/image_list.txt
@@ -0,0 +1,2 @@
+style_images/1.jpg NEATNESS
+style_images/2.jpg 锁店君和宾馆
diff --git a/style_text_rec/examples/style_images/1.jpg b/style_text_rec/examples/style_images/1.jpg
new file mode 100644
index 00000000..4da7838e
Binary files /dev/null and b/style_text_rec/examples/style_images/1.jpg differ
diff --git a/style_text_rec/examples/style_images/2.jpg b/style_text_rec/examples/style_images/2.jpg
new file mode 100644
index 00000000..f68ce49a
Binary files /dev/null and b/style_text_rec/examples/style_images/2.jpg differ
diff --git a/style_text_rec/fonts/ch_standard.ttf b/style_text_rec/fonts/ch_standard.ttf
new file mode 100755
index 00000000..cdb7fa59
Binary files /dev/null and b/style_text_rec/fonts/ch_standard.ttf differ
diff --git a/style_text_rec/fonts/en_standard.ttf b/style_text_rec/fonts/en_standard.ttf
new file mode 100755
index 00000000..2e31d024
Binary files /dev/null and b/style_text_rec/fonts/en_standard.ttf differ
diff --git a/style_text_rec/fonts/ko_standard.ttf b/style_text_rec/fonts/ko_standard.ttf
new file mode 100755
index 00000000..982bd879
Binary files /dev/null and b/style_text_rec/fonts/ko_standard.ttf differ
diff --git a/style_text_rec/tools/__init__.py b/style_text_rec/tools/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/style_text_rec/tools/synth_dataset.py b/style_text_rec/tools/synth_dataset.py
new file mode 100644
index 00000000..4a0e6d5e
--- /dev/null
+++ b/style_text_rec/tools/synth_dataset.py
@@ -0,0 +1,23 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from engine.synthesisers import DatasetSynthesiser
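+
+# Each run synthesises Global.output_num samples. The -t/--tag flag (parsed
+# in utils/config.py) namespaces the image directory and label file of each
+# worker (see engine/writers.py), so several processes can safely share one
+# output_dir.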
+
+
+def synth_dataset():
+ dataset_synthesiser = DatasetSynthesiser()
+ dataset_synthesiser.synth_dataset()
+
+
+if __name__ == '__main__':
+ synth_dataset()
diff --git a/style_text_rec/tools/synth_image.py b/style_text_rec/tools/synth_image.py
new file mode 100644
index 00000000..7b4827b8
--- /dev/null
+++ b/style_text_rec/tools/synth_image.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+# extend sys.path before importing project modules, so the script also works
+# when run directly rather than as `python3 -m tools.synth_image`
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+import cv2
+import glob
+
+from utils.config import ArgsParser
+from engine.synthesisers import ImageSynthesiser
+
+
+def synth_image():
+ args = ArgsParser().parse_args()
+ image_synthesiser = ImageSynthesiser()
+ style_image_path = args.style_image
+ img = cv2.imread(style_image_path)
+ text_corpus = args.text_corpus
+ language = args.language
+
+ synth_result = image_synthesiser.synth_image(text_corpus, img, language)
+ fake_fusion = synth_result["fake_fusion"]
+ fake_text = synth_result["fake_text"]
+ fake_bg = synth_result["fake_bg"]
+ cv2.imwrite("fake_fusion.jpg", fake_fusion)
+ cv2.imwrite("fake_text.jpg", fake_text)
+ cv2.imwrite("fake_bg.jpg", fake_bg)
+
+
+def batch_synth_images():
+ image_synthesiser = ImageSynthesiser()
+
+ corpus_file = "../StyleTextRec_data/test_20201208/test_text_list.txt"
+ style_data_dir = "../StyleTextRec_data/test_20201208/style_images/"
+    save_path = "./output_data/"
+    os.makedirs(save_path, exist_ok=True)
+ corpus_list = []
+ with open(corpus_file, "rb") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ substr = line.decode("utf-8").strip("\n").split("\t")
+ corpus_list.append(substr)
+ style_img_list = glob.glob("{}/*.jpg".format(style_data_dir))
+ corpus_num = len(corpus_list)
+ style_img_num = len(style_img_list)
+ for cno in range(corpus_num):
+ for sno in range(style_img_num):
+ corpus, lang = corpus_list[cno]
+ style_img_path = style_img_list[sno]
+ img = cv2.imread(style_img_path)
+ synth_result = image_synthesiser.synth_image(corpus, img, lang)
+ fake_fusion = synth_result["fake_fusion"]
+ fake_text = synth_result["fake_text"]
+ fake_bg = synth_result["fake_bg"]
+            for tp in range(2):
+                if tp == 0:
+                    prefix = "%s/c%d_s%d" % (save_path, cno, sno)
+                else:
+                    prefix = "%s/s%d_c%d" % (save_path, sno, cno)
+ cv2.imwrite("%s_fake_fusion.jpg" % prefix, fake_fusion)
+ cv2.imwrite("%s_fake_text.jpg" % prefix, fake_text)
+ cv2.imwrite("%s_fake_bg.jpg" % prefix, fake_bg)
+ cv2.imwrite("%s_input_style.jpg" % prefix, img)
+ print(cno, corpus_num, sno, style_img_num)
+
+
+if __name__ == '__main__':
+ # batch_synth_images()
+ synth_image()
diff --git a/style_text_rec/utils/__init__.py b/style_text_rec/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/style_text_rec/utils/config.py b/style_text_rec/utils/config.py
new file mode 100644
index 00000000..b2f8a618
--- /dev/null
+++ b/style_text_rec/utils/config.py
@@ -0,0 +1,224 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import yaml
+import os
+from argparse import ArgumentParser, RawDescriptionHelpFormatter
+
+from utils.logging import get_logger
+
+
+def override(dl, ks, v):
+ """
+ Recursively replace dict of list
+
+ Args:
+ dl(dict or list): dict or list to be replaced
+ ks(list): list of keys
+ v(str): value to be replaced
+ """
+
+ def str2num(v):
+ try:
+ return eval(v)
+ except Exception:
+ return v
+
+    assert isinstance(dl, (list, dict)), (
+        "{} should be a list or a dict".format(dl))
+    assert len(ks) > 0, ('length of keys should be larger than 0')
+ if isinstance(dl, list):
+ k = str2num(ks[0])
+ if len(ks) == 1:
+ assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
+ dl[k] = str2num(v)
+ else:
+ override(dl[k], ks[1:], v)
+ else:
+        if len(ks) == 1:
+            if ks[0] not in dl:
+                get_logger().warning('A new field ({}) detected!'.format(ks[
+                    0]))
+            dl[ks[0]] = str2num(v)
+        else:
+            assert ks[0] in dl, (
+                '({}) doesn\'t exist in {}, a new dict field is invalid'.
+                format(ks[0], dl))
+            override(dl[ks[0]], ks[1:], v)
+
+
+def override_config(config, options=None):
+ """
+ Recursively override the config
+
+ Args:
+ config(dict): dict to be replaced
+ options(list): list of pairs(key0.key1.idx.key2=value)
+ such as: [
+ 'topk=2',
+ 'VALID.transforms.1.ResizeImage.resize_short=300'
+ ]
+
+ Returns:
+ config(dict): replaced config
+ """
+ if options is not None:
+ for opt in options:
+ assert isinstance(opt, str), (
+ "option({}) should be a str".format(opt))
+ assert "=" in opt, (
+ "option({}) should contain a ="
+ "to distinguish between key and value".format(opt))
+ pair = opt.split('=')
+            assert len(pair) == 2, ("there can be only one = in the option")
+ key, value = pair
+ keys = key.split('.')
+ override(config, keys, value)
+
+ return config
+
+
+class ArgsParser(ArgumentParser):
+ def __init__(self):
+ super(ArgsParser, self).__init__(
+ formatter_class=RawDescriptionHelpFormatter)
+ self.add_argument("-c", "--config", help="configuration file to use")
+ self.add_argument(
+ "-t", "--tag", default="0", help="tag for marking worker")
+ self.add_argument(
+ '-o',
+ '--override',
+ action='append',
+ default=[],
+ help='config options to be overridden')
+        self.add_argument(
+            "--style_image", default="examples/style_images/1.jpg",
+            help="path of the style image")
+        self.add_argument(
+            "--text_corpus", default="PaddleOCR",
+            help="corpus to be rendered onto the style image")
+        self.add_argument(
+            "--language", default="en",
+            help="language of the corpus")
+
+ def parse_args(self, argv=None):
+ args = super(ArgsParser, self).parse_args(argv)
+ assert args.config is not None, \
+ "Please specify --config=configure_file_path."
+ return args
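+
+# Override example (hypothetical values): `-o Global.output_num=100` is split
+# on "=" and "." and routed through override_config/override above, setting
+# config["Global"]["output_num"] to 100.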
+
+
+def load_config(file_path):
+ """
+ Load config from yml/yaml file.
+ Args:
+ file_path (str): Path of the config file to be loaded.
+ Returns: config
+ """
+ ext = os.path.splitext(file_path)[1]
+ assert ext in ['.yml', '.yaml'], "only support yaml files for now"
+ with open(file_path, 'rb') as f:
+ config = yaml.load(f, Loader=yaml.Loader)
+
+ return config
+
+
+def gen_config():
+ base_config = {
+ "Global": {
+ "algorithm": "SRNet",
+ "use_gpu": True,
+ "start_epoch": 1,
+ "stage1_epoch_num": 100,
+ "stage2_epoch_num": 100,
+ "log_smooth_window": 20,
+ "print_batch_step": 2,
+ "save_model_dir": "./output/SRNet",
+ "use_visualdl": False,
+ "save_epoch_step": 10,
+ "vgg_pretrain": "./pretrained/VGG19_pretrained",
+ "vgg_load_static_pretrain": True
+ },
+ "Architecture": {
+ "model_type": "data_aug",
+ "algorithm": "SRNet",
+ "net_g": {
+ "name": "srnet_net_g",
+ "encode_dim": 64,
+ "norm": "batch",
+ "use_dropout": False,
+ "init_type": "xavier",
+ "init_gain": 0.02,
+ "use_dilation": 1
+ },
+            # discriminator defaults: n_layers_D=3, norm='instance',
+            # use_sigmoid=False, init_type='normal', init_gain=0.02,
+            # gpu_id='cuda:0'
+ "bg_discriminator": {
+ "name": "srnet_bg_discriminator",
+ "input_nc": 6,
+ "ndf": 64,
+ "netD": "basic",
+ "norm": "none",
+ "init_type": "xavier",
+ },
+ "fusion_discriminator": {
+ "name": "srnet_fusion_discriminator",
+ "input_nc": 6,
+ "ndf": 64,
+ "netD": "basic",
+ "norm": "none",
+ "init_type": "xavier",
+ }
+ },
+ "Loss": {
+ "lamb": 10,
+ "perceptual_lamb": 1,
+ "muvar_lamb": 50,
+ "style_lamb": 500
+ },
+ "Optimizer": {
+ "name": "Adam",
+ "learning_rate": {
+ "name": "lambda",
+ "lr": 0.0002,
+ "lr_decay_iters": 50
+ },
+ "beta1": 0.5,
+ "beta2": 0.999,
+ },
+ "Train": {
+ "batch_size_per_card": 8,
+ "num_workers_per_card": 4,
+ "dataset": {
+ "delimiter": "\t",
+ "data_dir": "/",
+ "label_file": "tmp/label.txt",
+ "transforms": [{
+ "DecodeImage": {
+ "to_rgb": True,
+ "to_np": False,
+ "channel_first": False
+ }
+ }, {
+ "NormalizeImage": {
+ "scale": 1. / 255.,
+ "mean": [0.485, 0.456, 0.406],
+ "std": [0.229, 0.224, 0.225],
+ "order": None
+ }
+ }, {
+ "ToCHWImage": None
+ }]
+ }
+ }
+ }
+ with open("config.yml", "w") as f:
+ yaml.dump(base_config, f)
+
+
+if __name__ == '__main__':
+ gen_config()
diff --git a/style_text_rec/utils/load_params.py b/style_text_rec/utils/load_params.py
new file mode 100644
index 00000000..be056136
--- /dev/null
+++ b/style_text_rec/utils/load_params.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import paddle
+
+__all__ = ['load_dygraph_pretrain']
+
+
+def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
+    if not os.path.exists(path + '.pdparams'):
+        raise ValueError("Model pretrain path {} does not "
+                         "exist.".format(path))
+ param_state_dict = paddle.load(path + '.pdparams')
+ model.set_state_dict(param_state_dict)
+ logger.info("load pretrained model from {}".format(path))
+ return
diff --git a/style_text_rec/utils/logging.py b/style_text_rec/utils/logging.py
new file mode 100644
index 00000000..f700fe26
--- /dev/null
+++ b/style_text_rec/utils/logging.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+import logging
+import functools
+import paddle.distributed as dist
+
+logger_initialized = {}
+
+
+@functools.lru_cache()
+def get_logger(name='srnet', log_file=None, log_level=logging.INFO):
+ """Initialize and get a logger by name.
+ If the logger has not been initialized, this method will initialize the
+ logger by adding one or two handlers, otherwise the initialized logger will
+ be directly returned. During initialization, a StreamHandler will always be
+ added. If `log_file` is specified a FileHandler will also be added.
+ Args:
+ name (str): Logger name.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the logger.
+ log_level (int): The logger level. Note that only the process of
+ rank 0 is affected, and other processes will set the level to
+ "Error" thus be silent most of the time.
+ Returns:
+ logging.Logger: The expected logger.
+ """
+ logger = logging.getLogger(name)
+ if name in logger_initialized:
+ return logger
+ for logger_name in logger_initialized:
+ if name.startswith(logger_name):
+ return logger
+
+ formatter = logging.Formatter(
+ '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
+ datefmt="%Y/%m/%d %H:%M:%S")
+
+ stream_handler = logging.StreamHandler(stream=sys.stdout)
+ stream_handler.setFormatter(formatter)
+ logger.addHandler(stream_handler)
+ if log_file is not None and dist.get_rank() == 0:
+ log_file_folder = os.path.split(log_file)[0]
+ os.makedirs(log_file_folder, exist_ok=True)
+ file_handler = logging.FileHandler(log_file, 'a')
+ file_handler.setFormatter(formatter)
+ logger.addHandler(file_handler)
+ if dist.get_rank() == 0:
+ logger.setLevel(log_level)
+ else:
+ logger.setLevel(logging.ERROR)
+ logger_initialized[name] = True
+ return logger
diff --git a/style_text_rec/utils/math_functions.py b/style_text_rec/utils/math_functions.py
new file mode 100644
index 00000000..3dc8d916
--- /dev/null
+++ b/style_text_rec/utils/math_functions.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle
+
+
+def compute_mean_covariance(img):
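+    """For a batch of NCHW images, return the per-channel mean with shape
+    [N, C, 1, 1] and the channel covariance matrix with shape [N, C, C]."""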
+ batch_size = img.shape[0]
+ channel_num = img.shape[1]
+ height = img.shape[2]
+ width = img.shape[3]
+ num_pixels = height * width
+
+ # batch_size * channel_num * 1 * 1
+ mu = img.mean(2, keepdim=True).mean(3, keepdim=True)
+
+ # batch_size * channel_num * num_pixels
+ img_hat = img - mu.expand_as(img)
+ img_hat = img_hat.reshape([batch_size, channel_num, num_pixels])
+ # batch_size * num_pixels * channel_num
+ img_hat_transpose = img_hat.transpose([0, 2, 1])
+ # batch_size * channel_num * channel_num
+ covariance = paddle.bmm(img_hat, img_hat_transpose)
+ covariance = covariance / num_pixels
+
+ return mu, covariance
+
+
+def dice_coefficient(y_true_cls, y_pred_cls, training_mask):
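+    """Dice loss on the masked region: 1 - 2*|X ∩ Y| / (|X| + |Y|), with a
+    small eps for numerical stability."""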
+ eps = 1e-5
+ intersection = paddle.sum(y_true_cls * y_pred_cls * training_mask)
+ union = paddle.sum(y_true_cls * training_mask) + paddle.sum(
+ y_pred_cls * training_mask) + eps
+ loss = 1. - (2 * intersection / union)
+ return loss
diff --git a/style_text_rec/utils/sys_funcs.py b/style_text_rec/utils/sys_funcs.py
new file mode 100644
index 00000000..203d91d8
--- /dev/null
+++ b/style_text_rec/utils/sys_funcs.py
@@ -0,0 +1,67 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import os
+import errno
+import paddle
+
+
+def get_check_global_params(mode):
+    check_params = [
+        'use_gpu', 'max_text_length', 'image_shape', 'character_type',
+        'loss_type'
+    ]
+ if mode == "train_eval":
+ check_params = check_params + [
+ 'train_batch_size_per_card', 'test_batch_size_per_card'
+ ]
+ elif mode == "test":
+ check_params = check_params + ['test_batch_size_per_card']
+ return check_params
+
+
+def check_gpu(use_gpu):
+ """
+ Log error and exit when set use_gpu=true in paddlepaddle
+ cpu version.
+ """
+    err = "Config use_gpu cannot be set as true while you are " \
+          "using paddlepaddle cpu version!\nPlease try: \n" \
+          "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
+          "\t2. Set use_gpu as false in config file to run " \
+          "model on CPU"
+ if use_gpu:
+ try:
+ if not paddle.is_compiled_with_cuda():
+ print(err)
+ sys.exit(1)
+        except Exception:
+            print("Failed to check GPU state.")
+            sys.exit(1)
+
+
+def _mkdir_if_not_exist(path, logger):
+ """
+ mkdir if not exists, ignore the exception when multiprocess mkdir together
+ """
+ if not os.path.exists(path):
+ try:
+ os.makedirs(path)
+ except OSError as e:
+            if e.errno == errno.EEXIST and os.path.isdir(path):
+                logger.warning(
+                    '{} already created by another process'.format(path))
+            else:
+                raise OSError('Failed to mkdir {}'.format(path))