fix conflicts

@@ -0,0 +1,106 @@
### Quick Start

`Style-Text` is an improved version of the SRNet network proposed in Baidu's self-developed text editing algorithm "Editing Text in the Wild". It differs from commonly used GAN-based methods: the tool decomposes the text synthesis task into three sub-modules to improve the quality of the synthetic data, namely a text style transfer module, a background extraction module and a fusion module.

The following figure shows some example results. In addition, the actual `nameplate text recognition` scene and `the Korean text recognition` scene verify the effectiveness of the synthesis tool, as follows.

#### Preparation

1. Please refer to [QUICK INSTALLATION](../doc/doc_en/installation_en.md) to install PaddlePaddle. A Python 3 environment is strongly recommended.
2. Download the pretrained models and unzip them:

```bash
cd tools/style_text_rec
wget /path/to/style_text_models.zip
unzip style_text_models.zip
```

You can download the models [here](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip). If you save the model files in another folder, please edit the three model paths in `configs/config.yml`:

```
bg_generator:
  pretrain: style_text_models/bg_generator
...
text_generator:
  pretrain: style_text_models/text_generator
...
fusion_generator:
  pretrain: style_text_models/fusion_generator
```

#### Demo

1. You can use the following command to run a demo:

```bash
python -m tools.synth_image -c configs/config.yml
```

2. The results are `fake_bg.jpg`, `fake_text.jpg` and `fake_fusion.jpg`, as shown in the figure above. Among them:
   * `fake_text.jpg` is the generated image with the same font style as `Style Input`;
   * `fake_bg.jpg` is the generated image of `Style Input` with the foreground removed;
   * `fake_fusion.jpg` is the final result, synthesized from `fake_text.jpg` and `fake_bg.jpg`.

3. If you want to generate an image from another `Style Input` or `Text Input`, you can modify `tools/synth_image.py`:
   * `img = cv2.imread("examples/style_images/1.jpg")`: the path of `Style Input`;
   * `corpus = "PaddleOCR"`: the `Text Input`;
   * Note: modify the language option (`language = "en"`) to match the `Text Input`; `en`, `ch` and `ko` are supported.

4. We also provide a `batch_synth_images` method, which combines corpora and pictures in pairs to generate a batch of data; a hedged sketch of that pairing follows.
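The exact signature of `batch_synth_images` is not reproduced in this document, so the snippet below only illustrates the pairwise combination it performs; the loop body and variable names are assumptions, not the confirmed API.

```python
import itertools

# Hypothetical illustration of the pairing done by batch_synth_images:
# every (style image, corpus) combination yields one synthetic sample.
style_images = ["examples/style_images/1.jpg", "examples/style_images/2.jpg"]
corpora = ["PaddleOCR", "Style-Text"]

for img_path, corpus in itertools.product(style_images, corpora):
    print("would synthesize text {!r} in the style of {}".format(corpus, img_path))
```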
### Advanced Usage

#### Components

`Style Text Rec` mainly contains the following components:

* `style_samplers`: sample a `Style Input` from a dataset. Currently, we only provide `DatasetSampler`.

* `corpus_generators`: generate corpora. Currently, we only provide two `corpus_generators` (a hedged sketch of this plug-in interface follows the list):
  * `EnNumCorpus`: generates a random string of a given length, consisting of uppercase and lowercase English letters, numbers and spaces.
  * `FileCorpus`: reads a text file and randomly returns words from it.

* `text_drawers`: generate the `Text Input` (a text picture rendered in a standard font from the input corpus). Note that the language information has to be modified to match the corpus.

* `predictors`: call the deep learning model to generate new data based on the `Style Input` and `Text Input`.

* `writers`: write the generated pictures (`fake_bg.jpg`, `fake_text.jpg` and `fake_fusion.jpg`) and label information to disk.

* `synthesisers`: call all of the modules above to complete the work.
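As a hedged illustration of the corpus-generator contract (inferred from `FileCorpus` in `engine/corpus_generators.py` later in this diff), a plug-in generator only needs a config-based constructor and a `generate()` method returning a `(language, corpus)` pair. The class below is hypothetical and not part of the repository:

```python
import random


class ShuffledFileCorpus(object):
    """Hypothetical generator mirroring the FileCorpus interface:
    __init__(config) and generate(corpus_length=0) -> (language, corpus)."""

    def __init__(self, config):
        corpus_file = config["CorpusGenerator"]["corpus_file"]
        self.language = config["CorpusGenerator"]["language"]
        with open(corpus_file, "r") as f:
            self.corpus_list = [line for line in f.read().split("\n") if line]

    def generate(self, corpus_length=0):
        corpus = random.choice(self.corpus_list)
        if corpus_length != 0:
            corpus = corpus[:corpus_length]
        return self.language, corpus
```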
### Generate Dataset

Before starting, you need to prepare some data as material.
First, you should have style reference data for the synthesis task; datasets used for OCR recognition tasks generally work well.

1. The reference dataset can be specified in `configs/dataset_config.yml`:
   * `StyleSampler`:
     * `method`: the method of `StyleSampler`.
     * `image_home`: the directory of the pictures.
     * `label_file`: the list of picture paths if `with_label` is `false`; otherwise, the label file path.
     * `with_label`: whether `label_file` is a label file.

   * `CorpusGenerator`:
     * `method`: the method of `CorpusGenerator`. If `FileCorpus` is used, you need to modify `corpus_file` and `language` accordingly; if `EnNumCorpus` is used, no other configuration is needed.
     * `language`: the language of the corpus. Needed if the method is not `EnNumCorpus`.
     * `corpus_file`: the corpus file path. Needed if the method is not `EnNumCorpus`.

2. You can run the following command to start the synthesis task:

``` bash
python -m tools.synth_dataset -c configs/dataset_config.yml
```

3. You can use the following commands to start multiple synthesis tasks in parallel, which requires specifying a tag for each task via `-t`:

```bash
python -m tools.synth_dataset -t 0 -c configs/dataset_config.yml
python -m tools.synth_dataset -t 1 -c configs/dataset_config.yml
```

### OCR Recognition Training

After completing the above operations, you can get a synthetic dataset for OCR recognition. Next, please complete the training by referring to the [OCR Recognition Document](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/recognition.md#%E5%90%AF%E5%8A%A8%E8%AE%AD%E7%BB%83).
@@ -0,0 +1,164 @@
## Style Text Rec

### Contents
- [Introduction](#introduction)
- [Environment Setup](#environment-setup)
- [Quick Start](#quick-start)
- [Advanced Usage](#advanced-usage)
- [Examples](#examples)

### Introduction
<div align="center">
    <img src="doc/images/3.png" width="800">
</div>

<div align="center">
    <img src="doc/images/1.png" width="600">
</div>

The Style-Text data synthesis tool is based on "Editing Text in the Wild" (https://arxiv.org/abs/1908.03047), a text editing algorithm developed by Baidu.
Unlike common GAN-based data synthesis tools, the main framework of Style-Text consists of: 1. a text foreground style transfer module; 2. a background extraction module; 3. a fusion module. With these three steps, the text style of an image can be transferred quickly. The figure below shows some results of this data synthesis tool.

<div align="center">
    <img src="doc/images/2.png" width="1000">
</div>

### Environment Setup

1. Refer to [QUICK INSTALLATION](../doc/doc_ch/installation.md) to install PaddleOCR.
2. Enter the `style_text_rec` directory, then download and unzip the models:

```bash
cd style_text_rec
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/style_text_models.zip
unzip style_text_models.zip
```

If you save the models in another location, edit the model paths in `configs/config.yml`; all three of the following entries need to be changed:

```
bg_generator:
  pretrain: style_text_models/bg_generator
...
text_generator:
  pretrain: style_text_models/text_generator
...
fusion_generator:
  pretrain: style_text_models/fusion_generator
```

### Quick Start

1. Run tools/synth_image to generate an example image:

```bash
python3 -m tools.synth_image -c configs/config.yml
```

2. After running, `fake_fusion.jpg` is generated, which is the final result.
<div align="center">
    <img src="doc/images/4.jpg" width="300">
</div>
In addition, the program also generates and saves intermediate results:
   * `fake_bg.jpg`: the background of the style reference image with the text removed;
   * `fake_text.jpg`: the provided string drawn on a gray background, imitating the text style of the style reference image.

3. If you want to try other style images and texts, you can add the `style_image`, `text_corpus` and `language` parameters:
```bash
python3 -m tools.synth_image -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en
```
   * Note: the language option must match the corpus; English, Simplified Chinese and Korean are currently supported.

4. In `tools/synth_image.py`, we also provide a `batch_synth_images` method, which combines corpora and images in pairs to generate a batch of data.

### Advanced Usage

Before starting to synthesize a dataset, some material needs to be prepared.

First, style images are needed as references for synthesis; datasets used for training OCR recognition models can serve this purpose. This example uses a dataset with label files as the style images.

1. Configure the input data paths in `configs/dataset_config.yml`:
   * `StyleSampler`:
     * `method`: the style image sampling method to use;
     * `image_home`: the directory of the style images;
     * `label_file`: the file listing the style image paths; if the dataset has labels, `label_file` is the label file path;
     * `with_label`: whether `label_file` is a label file.
   * `CorpusGenerator`:
     * `method`: the corpus generation method; currently `FileCorpus` and `EnNumCorpus` are available. If `EnNumCorpus` is used, no other configuration is needed; otherwise `corpus_file` and `language` need to be modified;
     * `language`: the language of the corpus;
     * `corpus_file`: the corpus file path.

We provide a set of [sample images](https://paddleocr.bj.bcebos.com/dygraph_v2.0/style_text/chkoen_5w.tar) for you to try; some examples are shown below:
<div align="center">
    <img src="doc/images/5.png" width="800">
</div>

2. Run `tools/synth_dataset` to synthesize data:

``` bash
python -m tools.synth_dataset -c configs/dataset_config.yml
```

3. If you want to synthesize data quickly in parallel, you can launch multiple processes; each needs a different `tag` (`-t`) at startup, as shown below (a hedged launcher sketch follows the commands):

```bash
python3 -m tools.synth_dataset -t 0 -c configs/dataset_config.yml
python3 -m tools.synth_dataset -t 1 -c configs/dataset_config.yml
```
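A small convenience sketch for launching the two tagged processes above from a single script; using `subprocess` this way is an assumption of this note, not a utility shipped with the repository:

```python
import subprocess

# Launch one tagged synthesis process per tag, then wait for both.
procs = [
    subprocess.Popen([
        "python3", "-m", "tools.synth_dataset",
        "-t", str(tag), "-c", "configs/dataset_config.yml"
    ])
    for tag in (0, 1)
]
for p in procs:
    p.wait()
```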
### Examples

After completing the above operations, you can obtain a synthetic dataset for OCR recognition. Some generated examples are shown below:
<div align="center">
    <img src="doc/images/6.png" width="800">
</div>
Please refer to the [OCR recognition documentation](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/recognition.md#%E5%90%AF%E5%8A%A8%E8%AE%AD%E7%BB%83) to complete the training.

Some results of training with the synthetic data are shown below:

| Scenario | Characters | Original samples | Test samples | Accuracy with original data only | Added synthetic samples | Accuracy with synthetic data | Improvement |
| ----------------- | ------------------ | ---- | ---- | ------ | ------ | ------ | ---- |
| Metal surface     | English and digits | 2203 | 650  | 0.5938 | 20000  | 0.7546 | 16%  |
| Random background | Korean             | 5631 | 1230 | 0.3012 | 100000 | 0.5057 | 20%  |

### Project Structure
```
style_text_rec
|-- arch
|   |-- base_module.py
|   |-- decoder.py
|   |-- encoder.py
|   |-- spectral_norm.py
|   `-- style_text_rec.py
|-- configs
|   |-- config.yml
|   `-- dataset_config.yml
|-- engine
|   |-- corpus_generators.py
|   |-- predictors.py
|   |-- style_samplers.py
|   |-- synthesisers.py
|   |-- text_drawers.py
|   `-- writers.py
|-- examples
|   |-- corpus
|   |   `-- example.txt
|   |-- image_list.txt
|   `-- style_images
|       |-- 1.jpg
|       `-- 2.jpg
|-- fonts
|   |-- ch_standard.ttf
|   |-- en_standard.ttf
|   `-- ko_standard.ttf
|-- tools
|   |-- __init__.py
|   |-- synth_dataset.py
|   `-- synth_image.py
`-- utils
    |-- config.py
    |-- load_params.py
    |-- logging.py
    |-- math_functions.py
    `-- sys_funcs.py
```
@@ -0,0 +1,255 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn

from arch.spectral_norm import spectral_norm


class CBN(nn.Layer):
    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super(CBN, self).__init__()
        if use_bias:
            bias_attr = paddle.ParamAttr(name=name + "_bias")
        else:
            bias_attr = None
        self._conv = paddle.nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            weight_attr=paddle.ParamAttr(name=name + "_weights"),
            bias_attr=bias_attr)
        if norm_layer:
            self._norm_layer = getattr(paddle.nn, norm_layer)(
                num_features=out_channels, name=name + "_bn")
        else:
            self._norm_layer = None
        if act:
            if act_attr:
                self._act = getattr(paddle.nn, act)(**act_attr,
                                                    name=name + "_" + act)
            else:
                self._act = getattr(paddle.nn, act)(name=name + "_" + act)
        else:
            self._act = None

    def forward(self, x):
        out = self._conv(x)
        if self._norm_layer:
            out = self._norm_layer(out)
        if self._act:
            out = self._act(out)
        return out


class SNConv(nn.Layer):
    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super(SNConv, self).__init__()
        if use_bias:
            bias_attr = paddle.ParamAttr(name=name + "_bias")
        else:
            bias_attr = None
        self._sn_conv = spectral_norm(
            paddle.nn.Conv2D(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                weight_attr=paddle.ParamAttr(name=name + "_weights"),
                bias_attr=bias_attr))
        if norm_layer:
            self._norm_layer = getattr(paddle.nn, norm_layer)(
                num_features=out_channels, name=name + "_bn")
        else:
            self._norm_layer = None
        if act:
            if act_attr:
                self._act = getattr(paddle.nn, act)(**act_attr,
                                                    name=name + "_" + act)
            else:
                self._act = getattr(paddle.nn, act)(name=name + "_" + act)
        else:
            self._act = None

    def forward(self, x):
        out = self._sn_conv(x)
        if self._norm_layer:
            out = self._norm_layer(out)
        if self._act:
            out = self._act(out)
        return out


class SNConvTranspose(nn.Layer):
    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super(SNConvTranspose, self).__init__()
        if use_bias:
            bias_attr = paddle.ParamAttr(name=name + "_bias")
        else:
            bias_attr = None
        self._sn_conv_transpose = spectral_norm(
            paddle.nn.Conv2DTranspose(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                output_padding=output_padding,
                dilation=dilation,
                groups=groups,
                weight_attr=paddle.ParamAttr(name=name + "_weights"),
                bias_attr=bias_attr))
        if norm_layer:
            self._norm_layer = getattr(paddle.nn, norm_layer)(
                num_features=out_channels, name=name + "_bn")
        else:
            self._norm_layer = None
        if act:
            if act_attr:
                self._act = getattr(paddle.nn, act)(**act_attr,
                                                    name=name + "_" + act)
            else:
                self._act = getattr(paddle.nn, act)(name=name + "_" + act)
        else:
            self._act = None

    def forward(self, x):
        out = self._sn_conv_transpose(x)
        if self._norm_layer:
            out = self._norm_layer(out)
        if self._act:
            out = self._act(out)
        return out


class MiddleNet(nn.Layer):
    def __init__(self, name, in_channels, mid_channels, out_channels,
                 use_bias):
        super(MiddleNet, self).__init__()
        self._sn_conv1 = SNConv(
            name=name + "_sn_conv1",
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=1,
            use_bias=use_bias,
            norm_layer=None,
            act=None)
        self._pad2d = nn.Pad2D(padding=[1, 1, 1, 1], mode="replicate")
        self._sn_conv2 = SNConv(
            name=name + "_sn_conv2",
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=3,
            use_bias=use_bias)
        self._sn_conv3 = SNConv(
            name=name + "_sn_conv3",
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_bias=use_bias)

    def forward(self, x):
        sn_conv1 = self._sn_conv1.forward(x)
        pad_2d = self._pad2d.forward(sn_conv1)
        sn_conv2 = self._sn_conv2.forward(pad_2d)
        sn_conv3 = self._sn_conv3.forward(sn_conv2)
        return sn_conv3


class ResBlock(nn.Layer):
    def __init__(self, name, channels, norm_layer, use_dropout, use_dilation,
                 use_bias):
        super(ResBlock, self).__init__()
        if use_dilation:
            padding_mat = [1, 1, 1, 1]
        else:
            padding_mat = [0, 0, 0, 0]
        self._pad1 = nn.Pad2D(padding_mat, mode="replicate")

        self._sn_conv1 = SNConv(
            name=name + "_sn_conv1",
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=0,
            norm_layer=norm_layer,
            use_bias=use_bias,
            act="ReLU",
            act_attr=None)
        if use_dropout:
            self._dropout = nn.Dropout(0.5)
        else:
            self._dropout = None
        self._pad2 = nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._sn_conv2 = SNConv(
            name=name + "_sn_conv2",
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            norm_layer=norm_layer,
            use_bias=use_bias,
            act="ReLU",
            act_attr=None)

    def forward(self, x):
        pad1 = self._pad1.forward(x)
        sn_conv1 = self._sn_conv1.forward(pad1)
        pad2 = self._pad2.forward(sn_conv1)
        sn_conv2 = self._sn_conv2.forward(pad2)
        return sn_conv2 + x
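A hedged usage sketch for `SNConv`; the 32x320 input size mirrors the dimensions in `configs/config.yml`, and nothing here is part of the module itself:

```python
import paddle

from arch.base_module import SNConv

# 3 -> 16 channel spectrally-normalized conv; padding=1 keeps the spatial size.
conv = SNConv(name="demo", in_channels=3, out_channels=16, kernel_size=3,
              padding=1, norm_layer=None, act="ReLU")
x = paddle.rand([1, 3, 32, 320])  # NCHW
print(conv(x).shape)              # [1, 16, 32, 320]
```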
@@ -0,0 +1,251 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn

from arch.base_module import SNConv, SNConvTranspose, ResBlock


class Decoder(nn.Layer):
    def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation, out_conv_act, out_conv_act_attr):
        super(Decoder, self).__init__()
        conv_blocks = []
        for i in range(conv_block_num):
            conv_blocks.append(
                ResBlock(
                    name="{}_conv_block_{}".format(name, i),
                    channels=encode_dim * 8,
                    norm_layer=norm_layer,
                    use_dropout=conv_block_dropout,
                    use_dilation=conv_block_dilation,
                    use_bias=use_bias))
        self.conv_blocks = nn.Sequential(*conv_blocks)
        self._up1 = SNConvTranspose(
            name=name + "_up1",
            in_channels=encode_dim * 8,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up2 = SNConvTranspose(
            name=name + "_up2",
            in_channels=encode_dim * 4,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up3 = SNConvTranspose(
            name=name + "_up3",
            in_channels=encode_dim * 2,
            out_channels=encode_dim,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._out_conv = SNConv(
            name=name + "_out_conv",
            in_channels=encode_dim,
            out_channels=out_channels,
            kernel_size=3,
            use_bias=use_bias,
            norm_layer=None,
            act=out_conv_act,
            act_attr=out_conv_act_attr)

    def forward(self, x):
        if isinstance(x, (list, tuple)):
            x = paddle.concat(x, axis=1)
        output_dict = dict()
        output_dict["conv_blocks"] = self.conv_blocks.forward(x)
        output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
        output_dict["up2"] = self._up2.forward(output_dict["up1"])
        output_dict["up3"] = self._up3.forward(output_dict["up2"])
        output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
        output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
        return output_dict


class DecoderUnet(nn.Layer):
    def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation, out_conv_act, out_conv_act_attr):
        super(DecoderUnet, self).__init__()
        conv_blocks = []
        for i in range(conv_block_num):
            conv_blocks.append(
                ResBlock(
                    name="{}_conv_block_{}".format(name, i),
                    channels=encode_dim * 8,
                    norm_layer=norm_layer,
                    use_dropout=conv_block_dropout,
                    use_dilation=conv_block_dilation,
                    use_bias=use_bias))
        self._conv_blocks = nn.Sequential(*conv_blocks)
        self._up1 = SNConvTranspose(
            name=name + "_up1",
            in_channels=encode_dim * 8,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up2 = SNConvTranspose(
            name=name + "_up2",
            in_channels=encode_dim * 8,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up3 = SNConvTranspose(
            name=name + "_up3",
            in_channels=encode_dim * 4,
            out_channels=encode_dim,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._out_conv = SNConv(
            name=name + "_out_conv",
            in_channels=encode_dim,
            out_channels=out_channels,
            kernel_size=3,
            use_bias=use_bias,
            norm_layer=None,
            act=out_conv_act,
            act_attr=out_conv_act_attr)

    def forward(self, x, y, feature2, feature1):
        output_dict = dict()
        output_dict["conv_blocks"] = self._conv_blocks(
            paddle.concat(
                (x, y), axis=1))
        output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
        output_dict["up2"] = self._up2.forward(
            paddle.concat(
                (output_dict["up1"], feature2), axis=1))
        output_dict["up3"] = self._up3.forward(
            paddle.concat(
                (output_dict["up2"], feature1), axis=1))
        output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
        output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
        return output_dict


class SingleDecoder(nn.Layer):
    def __init__(self, name, encode_dim, out_channels, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation, out_conv_act, out_conv_act_attr):
        super(SingleDecoder, self).__init__()
        conv_blocks = []
        for i in range(conv_block_num):
            conv_blocks.append(
                ResBlock(
                    name="{}_conv_block_{}".format(name, i),
                    channels=encode_dim * 4,
                    norm_layer=norm_layer,
                    use_dropout=conv_block_dropout,
                    use_dilation=conv_block_dilation,
                    use_bias=use_bias))
        self._conv_blocks = nn.Sequential(*conv_blocks)
        self._up1 = SNConvTranspose(
            name=name + "_up1",
            in_channels=encode_dim * 4,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up2 = SNConvTranspose(
            name=name + "_up2",
            in_channels=encode_dim * 8,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up3 = SNConvTranspose(
            name=name + "_up3",
            in_channels=encode_dim * 4,
            out_channels=encode_dim,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._pad2d = paddle.nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._out_conv = SNConv(
            name=name + "_out_conv",
            in_channels=encode_dim,
            out_channels=out_channels,
            kernel_size=3,
            use_bias=use_bias,
            norm_layer=None,
            act=out_conv_act,
            act_attr=out_conv_act_attr)

    def forward(self, x, feature2, feature1):
        output_dict = dict()
        output_dict["conv_blocks"] = self._conv_blocks.forward(x)
        output_dict["up1"] = self._up1.forward(output_dict["conv_blocks"])
        output_dict["up2"] = self._up2.forward(
            paddle.concat(
                (output_dict["up1"], feature2), axis=1))
        output_dict["up3"] = self._up3.forward(
            paddle.concat(
                (output_dict["up2"], feature1), axis=1))
        output_dict["pad2d"] = self._pad2d.forward(output_dict["up3"])
        output_dict["out_conv"] = self._out_conv.forward(output_dict["pad2d"])
        return output_dict
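A hedged shape walk for `Decoder`: its three stride-2 transposed convolutions upsample by 8x overall, and the final pad-plus-valid-conv pair preserves the size. The constructor arguments below mirror the `text_generator` entry in `configs/config.yml`; this is a sketch, not repository code:

```python
import paddle

from arch.decoder import Decoder

dec = Decoder(name="demo_dec", encode_dim=64, out_channels=3, use_bias=True,
              norm_layer="InstanceNorm2D", act="ReLU", act_attr=None,
              conv_block_dropout=False, conv_block_num=4,
              conv_block_dilation=True, out_conv_act="Tanh",
              out_conv_act_attr=None)
feat = paddle.rand([1, 512, 4, 40])  # encode_dim * 8 channels, a 32x320 image / 8
out = dec(feat)
print(out["out_conv"].shape)         # [1, 3, 32, 320]
```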
@@ -0,0 +1,186 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn

from arch.base_module import SNConv, SNConvTranspose, ResBlock


class Encoder(nn.Layer):
    def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
                 act, act_attr, conv_block_dropout, conv_block_num,
                 conv_block_dilation):
        super(Encoder, self).__init__()
        self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
        self._in_conv = SNConv(
            name=name + "_in_conv",
            in_channels=in_channels,
            out_channels=encode_dim,
            kernel_size=7,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down1 = SNConv(
            name=name + "_down1",
            in_channels=encode_dim,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down2 = SNConv(
            name=name + "_down2",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down3 = SNConv(
            name=name + "_down3",
            in_channels=encode_dim * 4,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        conv_blocks = []
        for i in range(conv_block_num):
            conv_blocks.append(
                ResBlock(
                    name="{}_conv_block_{}".format(name, i),
                    channels=encode_dim * 4,
                    norm_layer=norm_layer,
                    use_dropout=conv_block_dropout,
                    use_dilation=conv_block_dilation,
                    use_bias=use_bias))
        self._conv_blocks = nn.Sequential(*conv_blocks)

    def forward(self, x):
        out_dict = dict()
        x = self._pad2d(x)
        out_dict["in_conv"] = self._in_conv.forward(x)
        out_dict["down1"] = self._down1.forward(out_dict["in_conv"])
        out_dict["down2"] = self._down2.forward(out_dict["down1"])
        out_dict["down3"] = self._down3.forward(out_dict["down2"])
        out_dict["res_blocks"] = self._conv_blocks.forward(out_dict["down3"])
        return out_dict


class EncoderUnet(nn.Layer):
    def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
                 act, act_attr):
        super(EncoderUnet, self).__init__()
        self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
        self._in_conv = SNConv(
            name=name + "_in_conv",
            in_channels=in_channels,
            out_channels=encode_dim,
            kernel_size=7,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down1 = SNConv(
            name=name + "_down1",
            in_channels=encode_dim,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down2 = SNConv(
            name=name + "_down2",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down3 = SNConv(
            name=name + "_down3",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._down4 = SNConv(
            name=name + "_down4",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up1 = SNConvTranspose(
            name=name + "_up1",
            in_channels=encode_dim * 2,
            out_channels=encode_dim * 2,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)
        self._up2 = SNConvTranspose(
            name=name + "_up2",
            in_channels=encode_dim * 4,
            out_channels=encode_dim * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act=act,
            act_attr=act_attr)

    def forward(self, x):
        output_dict = dict()
        x = self._pad2d(x)
        output_dict['in_conv'] = self._in_conv.forward(x)
        output_dict['down1'] = self._down1.forward(output_dict['in_conv'])
        output_dict['down2'] = self._down2.forward(output_dict['down1'])
        output_dict['down3'] = self._down3.forward(output_dict['down2'])
        output_dict['down4'] = self._down4.forward(output_dict['down3'])
        output_dict['up1'] = self._up1.forward(output_dict['down4'])
        output_dict['up2'] = self._up2.forward(
            paddle.concat(
                (output_dict['down3'], output_dict['up1']), axis=1))
        output_dict['concat'] = paddle.concat(
            (output_dict['down2'], output_dict['up2']), axis=1)
        return output_dict
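A hedged sketch of the multi-scale feature dict `Encoder.forward` returns; constructor arguments again mirror the `text_generator` entry in `configs/config.yml`:

```python
import paddle

from arch.encoder import Encoder

enc = Encoder(name="demo_enc", in_channels=3, encode_dim=64, use_bias=True,
              norm_layer="InstanceNorm2D", act="ReLU", act_attr=None,
              conv_block_dropout=False, conv_block_num=4,
              conv_block_dilation=True)
feats = enc(paddle.rand([1, 3, 32, 320]))
for key, value in feats.items():
    print(key, value.shape)
# in_conv [1, 64, 32, 320], down1 [1, 128, 16, 160],
# down2 [1, 256, 8, 80], down3 and res_blocks [1, 256, 4, 40]
```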
@@ -0,0 +1,150 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F


def normal_(x, mean=0., std=1.):
    temp_value = paddle.normal(mean, std, shape=x.shape)
    x.set_value(temp_value)
    return x


class SpectralNorm(object):
    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(
                                 n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight):
        weight_mat = weight
        if self.dim != 0:
            # transpose dim to front
            weight_mat = weight_mat.transpose([
                self.dim,
                *[d for d in range(weight_mat.dim()) if d != self.dim]
            ])

        height = weight_mat.shape[0]

        return weight_mat.reshape([height, -1])

    def compute_weight(self, module, do_power_iteration):
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with paddle.no_grad():
                for _ in range(self.n_power_iterations):
                    v.set_value(
                        F.normalize(
                            paddle.matmul(
                                weight_mat,
                                u,
                                transpose_x=True,
                                transpose_y=False),
                            axis=0,
                            epsilon=self.eps))
                    u.set_value(
                        F.normalize(
                            paddle.matmul(weight_mat, v),
                            axis=0,
                            epsilon=self.eps))
                if self.n_power_iterations > 0:
                    u = u.clone()
                    v = v.clone()

        sigma = paddle.dot(u, paddle.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, module):
        with paddle.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_v')
        delattr(module, self.name + '_orig')

        module.add_parameter(self.name, weight.detach())

    def __call__(self, module, inputs):
        setattr(
            module,
            self.name,
            self.compute_weight(
                module, do_power_iteration=module.training))

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(
                    "Cannot register two spectral_norm hooks on "
                    "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]

        with paddle.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)
            h, w = weight_mat.shape

            # randomly initialize u and v
            u = module.create_parameter([h])
            u = normal_(u, 0., 1.)
            v = module.create_parameter([w])
            v = normal_(v, 0., 1.)
            u = F.normalize(u, axis=0, epsilon=fn.eps)
            v = F.normalize(v, axis=0, epsilon=fn.eps)

        # delete fn.name from parameters, otherwise you cannot set the attribute
        del module._parameters[fn.name]
        module.add_parameter(fn.name + "_orig", weight)
        # Still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign it, because it could be a Parameter
        # and would get added as a parameter. Instead, we register weight * 1.0
        # as a plain attribute.
        setattr(module, fn.name, weight * 1.0)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)
        return fn


def spectral_norm(module,
                  name='weight',
                  n_power_iterations=1,
                  eps=1e-12,
                  dim=None):

    if dim is None:
        if isinstance(module, (nn.Conv1DTranspose, nn.Conv2DTranspose,
                               nn.Conv3DTranspose, nn.Linear)):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
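A hedged sanity check for the power iteration above: after enough forward passes the largest singular value of the normalized weight should approach 1. The iteration counts below are arbitrary choices for fast convergence, not values from the repository:

```python
import numpy as np
import paddle

from arch.spectral_norm import spectral_norm

paddle.seed(0)
layer = spectral_norm(paddle.nn.Linear(8, 8), n_power_iterations=20)
x = paddle.rand([1, 8])
for _ in range(50):  # each forward pass refines the u and v estimates
    layer(x)
sigma = np.linalg.svd(layer.weight.numpy(), compute_uv=False)[0]
print(sigma)  # should be close to 1.0
```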
@@ -0,0 +1,285 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn

from arch.base_module import MiddleNet, ResBlock
from arch.encoder import Encoder
from arch.decoder import Decoder, DecoderUnet, SingleDecoder
from utils.load_params import load_dygraph_pretrain
from utils.logging import get_logger


class StyleTextRec(nn.Layer):
    def __init__(self, config):
        super(StyleTextRec, self).__init__()
        self.logger = get_logger()
        self.text_generator = TextGenerator(config["Predictor"][
            "text_generator"])
        self.bg_generator = BgGeneratorWithMask(config["Predictor"][
            "bg_generator"])
        self.fusion_generator = FusionGeneratorSimple(config["Predictor"][
            "fusion_generator"])
        bg_generator_pretrain = config["Predictor"]["bg_generator"]["pretrain"]
        text_generator_pretrain = config["Predictor"]["text_generator"][
            "pretrain"]
        fusion_generator_pretrain = config["Predictor"]["fusion_generator"][
            "pretrain"]
        load_dygraph_pretrain(
            self.bg_generator,
            self.logger,
            path=bg_generator_pretrain,
            load_static_weights=False)
        load_dygraph_pretrain(
            self.text_generator,
            self.logger,
            path=text_generator_pretrain,
            load_static_weights=False)
        load_dygraph_pretrain(
            self.fusion_generator,
            self.logger,
            path=fusion_generator_pretrain,
            load_static_weights=False)

    def forward(self, style_input, text_input):
        text_gen_output = self.text_generator.forward(style_input, text_input)
        fake_text = text_gen_output["fake_text"]
        fake_sk = text_gen_output["fake_sk"]
        bg_gen_output = self.bg_generator.forward(style_input)
        bg_encode_feature = bg_gen_output["bg_encode_feature"]
        bg_decode_feature1 = bg_gen_output["bg_decode_feature1"]
        bg_decode_feature2 = bg_gen_output["bg_decode_feature2"]
        fake_bg = bg_gen_output["fake_bg"]

        fusion_gen_output = self.fusion_generator.forward(fake_text, fake_bg)
        fake_fusion = fusion_gen_output["fake_fusion"]
        return {
            "fake_fusion": fake_fusion,
            "fake_text": fake_text,
            "fake_sk": fake_sk,
            "fake_bg": fake_bg,
        }


class TextGenerator(nn.Layer):
    def __init__(self, config):
        super(TextGenerator, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        conv_block_dropout = config["conv_block_dropout"]
        conv_block_num = config["conv_block_num"]
        conv_block_dilation = config["conv_block_dilation"]
        if norm_layer == "InstanceNorm2D":
            use_bias = True
        else:
            use_bias = False
        self.encoder_text = Encoder(
            name=name + "_encoder_text",
            in_channels=3,
            encode_dim=encode_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation)
        self.encoder_style = Encoder(
            name=name + "_encoder_style",
            in_channels=3,
            encode_dim=encode_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation)
        self.decoder_text = Decoder(
            name=name + "_decoder_text",
            encode_dim=encode_dim,
            out_channels=int(encode_dim / 2),
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Tanh",
            out_conv_act_attr=None)
        self.decoder_sk = Decoder(
            name=name + "_decoder_sk",
            encode_dim=encode_dim,
            out_channels=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Sigmoid",
            out_conv_act_attr=None)

        self.middle = MiddleNet(
            name=name + "_middle_net",
            in_channels=int(encode_dim / 2) + 1,
            mid_channels=encode_dim,
            out_channels=3,
            use_bias=use_bias)

    def forward(self, style_input, text_input):
        style_feature = self.encoder_style.forward(style_input)["res_blocks"]
        text_feature = self.encoder_text.forward(text_input)["res_blocks"]
        fake_c_temp = self.decoder_text.forward([text_feature,
                                                 style_feature])["out_conv"]
        fake_sk = self.decoder_sk.forward([text_feature,
                                           style_feature])["out_conv"]
        fake_text = self.middle(paddle.concat((fake_c_temp, fake_sk), axis=1))
        return {"fake_sk": fake_sk, "fake_text": fake_text}


class BgGeneratorWithMask(nn.Layer):
    def __init__(self, config):
        super(BgGeneratorWithMask, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        conv_block_dropout = config["conv_block_dropout"]
        conv_block_num = config["conv_block_num"]
        conv_block_dilation = config["conv_block_dilation"]
        self.output_factor = config.get("output_factor", 1.0)

        if norm_layer == "InstanceNorm2D":
            use_bias = True
        else:
            use_bias = False

        self.encoder_bg = Encoder(
            name=name + "_encoder_bg",
            in_channels=3,
            encode_dim=encode_dim,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation)

        self.decoder_bg = SingleDecoder(
            name=name + "_decoder_bg",
            encode_dim=encode_dim,
            out_channels=3,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Tanh",
            out_conv_act_attr=None)

        self.decoder_mask = Decoder(
            name=name + "_decoder_mask",
            encode_dim=encode_dim // 2,
            out_channels=1,
            use_bias=use_bias,
            norm_layer=norm_layer,
            act="ReLU",
            act_attr=None,
            conv_block_dropout=conv_block_dropout,
            conv_block_num=conv_block_num,
            conv_block_dilation=conv_block_dilation,
            out_conv_act="Sigmoid",
            out_conv_act_attr=None)

        self.middle = MiddleNet(
            name=name + "_middle_net",
            in_channels=3 + 1,
            mid_channels=encode_dim,
            out_channels=3,
            use_bias=use_bias)

    def forward(self, style_input):
        encode_bg_output = self.encoder_bg(style_input)
        decode_bg_output = self.decoder_bg(encode_bg_output["res_blocks"],
                                           encode_bg_output["down2"],
                                           encode_bg_output["down1"])

        fake_c_temp = decode_bg_output["out_conv"]
        fake_bg_mask = self.decoder_mask.forward(encode_bg_output[
            "res_blocks"])["out_conv"]
        fake_bg = self.middle(
            paddle.concat(
                (fake_c_temp, fake_bg_mask), axis=1))
        return {
            "bg_encode_feature": encode_bg_output["res_blocks"],
            "bg_decode_feature1": decode_bg_output["up1"],
            "bg_decode_feature2": decode_bg_output["up2"],
            "fake_bg": fake_bg,
            "fake_bg_mask": fake_bg_mask,
        }


class FusionGeneratorSimple(nn.Layer):
    def __init__(self, config):
        super(FusionGeneratorSimple, self).__init__()
        name = config["module_name"]
        encode_dim = config["encode_dim"]
        norm_layer = config["norm_layer"]
        conv_block_dropout = config["conv_block_dropout"]
        conv_block_dilation = config["conv_block_dilation"]
        if norm_layer == "InstanceNorm2D":
            use_bias = True
        else:
            use_bias = False

        self._conv = nn.Conv2D(
            in_channels=6,
            out_channels=encode_dim,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=paddle.ParamAttr(name=name + "_conv_weights"),
            bias_attr=False)

        self._res_block = ResBlock(
            name="{}_conv_block".format(name),
            channels=encode_dim,
            norm_layer=norm_layer,
            use_dropout=conv_block_dropout,
            use_dilation=conv_block_dilation,
            use_bias=use_bias)

        self._reduce_conv = nn.Conv2D(
            in_channels=encode_dim,
            out_channels=3,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=1,
            weight_attr=paddle.ParamAttr(name=name + "_reduce_conv_weights"),
            bias_attr=False)

    def forward(self, fake_text, fake_bg):
        fake_concat = paddle.concat((fake_text, fake_bg), axis=1)
        fake_concat_tmp = self._conv(fake_concat)
        output_res = self._res_block(fake_concat_tmp)
        fake_fusion = self._reduce_conv(output_res)
        return {"fake_fusion": fake_fusion}
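A hedged end-to-end sketch of the three-stage pipeline. It assumes the pretrained weights from the Preparation step are present at the paths in `configs/config.yml`, and plain `yaml.safe_load` is used here as a stand-in for the repo's own loader in `utils/config.py`:

```python
import paddle
import yaml  # assumption: PyYAML stands in for utils/config.py here

from arch.style_text_rec import StyleTextRec

with open("configs/config.yml") as f:
    config = yaml.safe_load(f)

# The constructor loads the three pretrained sub-generators from disk.
model = StyleTextRec(config)
style = paddle.rand([1, 3, 32, 320])  # stands in for a preprocessed Style Input
text = paddle.rand([1, 3, 32, 320])   # stands in for a preprocessed Text Input
out = model(style, text)
print(sorted(out.keys()))  # ['fake_bg', 'fake_fusion', 'fake_sk', 'fake_text']
```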
@@ -0,0 +1,54 @@
Global:
  output_num: 10
  output_dir: output_data
  use_gpu: false
  image_height: 32
  image_width: 320
TextDrawer:
  fonts:
    en: fonts/en_standard.ttf
    ch: fonts/ch_standard.ttf
    ko: fonts/ko_standard.ttf
Predictor:
  method: StyleTextRecPredictor
  algorithm: StyleTextRec
  scale: 0.00392156862745098
  mean:
    - 0.5
    - 0.5
    - 0.5
  std:
    - 0.5
    - 0.5
    - 0.5
  expand_result: false
  bg_generator:
    pretrain: style_text_models/bg_generator
    module_name: bg_generator
    generator_type: BgGeneratorWithMask
    encode_dim: 64
    norm_layer: null
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
    output_factor: 1.05
  text_generator:
    pretrain: style_text_models/text_generator
    module_name: text_generator
    generator_type: TextGenerator
    encode_dim: 64
    norm_layer: InstanceNorm2D
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
  fusion_generator:
    pretrain: style_text_models/fusion_generator
    module_name: fusion_generator
    generator_type: FusionGeneratorSimple
    encode_dim: 64
    norm_layer: null
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
Writer:
  method: SimpleWriter
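Note that `scale` is just 1/255, so together with the `mean`/`std` of 0.5 the `preprocess` step in `engine/predictors.py` maps 0-255 pixel values onto [-1, 1]; a quick check:

```python
scale, mean, std = 1 / 255, 0.5, 0.5
for pixel in (0, 127.5, 255):
    print((pixel * scale - mean) / std)  # -1.0, 0.0, 1.0
```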
@@ -0,0 +1,64 @@
Global:
  output_num: 10
  output_dir: output_data
  use_gpu: false
  image_height: 32
  image_width: 320
  standard_font: fonts/en_standard.ttf
TextDrawer:
  fonts:
    en: fonts/en_standard.ttf
    ch: fonts/ch_standard.ttf
    ko: fonts/ko_standard.ttf
StyleSampler:
  method: DatasetSampler
  image_home: examples
  label_file: examples/image_list.txt
  with_label: true
CorpusGenerator:
  method: FileCorpus
  language: ch
  corpus_file: examples/corpus/example.txt
Predictor:
  method: StyleTextRecPredictor
  algorithm: StyleTextRec
  scale: 0.00392156862745098
  mean:
    - 0.5
    - 0.5
    - 0.5
  std:
    - 0.5
    - 0.5
    - 0.5
  expand_result: false
  bg_generator:
    pretrain: models/style_text_rec/bg_generator
    module_name: bg_generator
    generator_type: BgGeneratorWithMask
    encode_dim: 64
    norm_layer: null
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
    output_factor: 1.05
  text_generator:
    pretrain: models/style_text_rec/text_generator
    module_name: text_generator
    generator_type: TextGenerator
    encode_dim: 64
    norm_layer: InstanceNorm2D
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
  fusion_generator:
    pretrain: models/style_text_rec/fusion_generator
    module_name: fusion_generator
    generator_type: FusionGeneratorSimple
    encode_dim: 64
    norm_layer: null
    conv_block_num: 4
    conv_block_dropout: false
    conv_block_dilation: true
Writer:
  method: SimpleWriter
@@ -0,0 +1,66 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random

from utils.logging import get_logger


class FileCorpus(object):
    def __init__(self, config):
        self.logger = get_logger()
        self.logger.info("using FileCorpus")

        self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

        corpus_file = config["CorpusGenerator"]["corpus_file"]
        self.language = config["CorpusGenerator"]["language"]
        with open(corpus_file, 'r') as f:
            corpus_raw = f.read()
        self.corpus_list = corpus_raw.split("\n")[:-1]
        assert len(self.corpus_list) > 0
        random.shuffle(self.corpus_list)
        self.index = 0

    def generate(self, corpus_length=0):
        if self.index >= len(self.corpus_list):
            self.index = 0
            random.shuffle(self.corpus_list)
        corpus = self.corpus_list[self.index]
        if corpus_length != 0:
            corpus = corpus[0:corpus_length]
        if corpus_length > len(corpus):
            self.logger.warning("generated corpus is shorter than expected.")
        self.index += 1
        return self.language, corpus


class EnNumCorpus(object):
    def __init__(self, config):
        self.logger = get_logger()
        self.logger.info("using NumberCorpus")
        self.num_list = "0123456789"
        self.en_char_list = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
        self.height = config["Global"]["image_height"]
        self.max_width = config["Global"]["image_width"]

    def generate(self, corpus_length=0):
        corpus = ""
        if corpus_length == 0:
            corpus_length = random.randint(5, 15)
        for i in range(corpus_length):
            if random.random() < 0.2:
                corpus += "{}".format(random.choice(self.en_char_list))
            else:
                corpus += "{}".format(random.choice(self.num_list))
        return "en", corpus
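A quick usage sketch for `EnNumCorpus` (run from the `style_text_rec` root so `engine` and `utils` are importable):

```python
from engine.corpus_generators import EnNumCorpus

# Minimal config: EnNumCorpus only reads the two Global size fields.
config = {"Global": {"image_height": 32, "image_width": 320}}
gen = EnNumCorpus(config)
print(gen.generate())   # random length 5-15, e.g. ('en', '83A0492k51')
print(gen.generate(4))  # fixed length, e.g. ('en', '7B19')
```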
@ -0,0 +1,115 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import cv2
import math
import paddle

from arch import style_text_rec
from utils.sys_funcs import check_gpu
from utils.logging import get_logger


class StyleTextRecPredictor(object):
    def __init__(self, config):
        algorithm = config['Predictor']['algorithm']
        assert algorithm in ["StyleTextRec"
                             ], "Generator {} not supported.".format(algorithm)
        use_gpu = config["Global"]['use_gpu']
        check_gpu(use_gpu)
        self.logger = get_logger()
        self.generator = getattr(style_text_rec, algorithm)(config)
        self.height = config["Global"]["image_height"]
        self.width = config["Global"]["image_width"]
        self.scale = config["Predictor"]["scale"]
        self.mean = config["Predictor"]["mean"]
        self.std = config["Predictor"]["std"]
        self.expand_result = config["Predictor"]["expand_result"]

    def predict(self, style_input, text_input):
        style_input = self.rep_style_input(style_input, text_input)
        tensor_style_input = self.preprocess(style_input)
        tensor_text_input = self.preprocess(text_input)
        style_text_result = self.generator.forward(tensor_style_input,
                                                   tensor_text_input)
        fake_fusion = self.postprocess(style_text_result["fake_fusion"])
        fake_text = self.postprocess(style_text_result["fake_text"])
        fake_sk = self.postprocess(style_text_result["fake_sk"])
        fake_bg = self.postprocess(style_text_result["fake_bg"])
        bbox = self.get_text_boundary(fake_text)
        if bbox:
            left, right, top, bottom = bbox
            fake_fusion = fake_fusion[top:bottom, left:right, :]
            fake_text = fake_text[top:bottom, left:right, :]
            fake_sk = fake_sk[top:bottom, left:right, :]
            fake_bg = fake_bg[top:bottom, left:right, :]

        # fake_fusion = self.crop_by_text(img_fake_fusion, img_fake_text)
        return {
            "fake_fusion": fake_fusion,
            "fake_text": fake_text,
            "fake_sk": fake_sk,
            "fake_bg": fake_bg,
        }

    def preprocess(self, img):
        img = (img.astype('float32') * self.scale - self.mean) / self.std
        img_height, img_width, channel = img.shape
        assert channel == 3, "Please use an RGB image."
        ratio = img_width / float(img_height)
        if math.ceil(self.height * ratio) > self.width:
            resized_w = self.width
        else:
            resized_w = int(math.ceil(self.height * ratio))
        img = cv2.resize(img, (resized_w, self.height))

        new_img = np.zeros([self.height, self.width, 3]).astype('float32')
        new_img[:, 0:resized_w, :] = img
        img = new_img.transpose((2, 0, 1))
        img = img[np.newaxis, :, :, :]
        return paddle.to_tensor(img)

    def postprocess(self, tensor):
        img = tensor.numpy()[0]
        img = img.transpose((1, 2, 0))
        img = (img * self.std + self.mean) / self.scale
        img = np.maximum(img, 0.0)
        img = np.minimum(img, 255.0)
        img = img.astype('uint8')
        return img

    def rep_style_input(self, style_input, text_input):
        rep_num = int(1.2 * (text_input.shape[1] / text_input.shape[0]) /
                      (style_input.shape[1] / style_input.shape[0])) + 1
        style_input = np.tile(style_input, reps=[1, rep_num, 1])
        max_width = int(self.width / self.height * style_input.shape[0])
        style_input = style_input[:, :max_width, :]
        return style_input

    def get_text_boundary(self, text_img):
        img_height = text_img.shape[0]
        img_width = text_img.shape[1]
        bounder = 3
        text_canny_img = cv2.Canny(text_img, 10, 20)
        edge_num_h = text_canny_img.sum(axis=0)
        no_zero_list_h = np.where(edge_num_h > 0)[0]
        edge_num_w = text_canny_img.sum(axis=1)
        no_zero_list_w = np.where(edge_num_w > 0)[0]
        if len(no_zero_list_h) == 0 or len(no_zero_list_w) == 0:
            return None
        left = max(no_zero_list_h[0] - bounder, 0)
        right = min(no_zero_list_h[-1] + bounder, img_width)
        top = max(no_zero_list_w[0] - bounder, 0)
        bottom = min(no_zero_list_w[-1] + bounder, img_height)
        return [left, right, top, bottom]
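The trickiest step above is `rep_style_input`: the style image is tiled horizontally until its aspect ratio covers the drawn text (with a 1.2 margin), then cropped. A self-contained sketch of that rule, under the assumption that images are `(H, W, C)` numpy arrays:

```python
import numpy as np

def tile_to_ratio(style_img, text_w_over_h):
    # Repeat the style crop sideways until it is at least 1.2x as wide,
    # relative to height, as the rendered text image.
    style_ratio = style_img.shape[1] / style_img.shape[0]
    rep_num = int(1.2 * text_w_over_h / style_ratio) + 1
    return np.tile(style_img, reps=[1, rep_num, 1])

style = np.zeros((32, 64, 3), dtype=np.uint8)     # a 2:1 style crop
tiled = tile_to_ratio(style, text_w_over_h=10.0)  # text 10x wider than tall
assert tiled.shape[1] / tiled.shape[0] >= 1.2 * 10.0
```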
@ -0,0 +1,62 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import random
import cv2


class DatasetSampler(object):
    def __init__(self, config):
        self.image_home = config["StyleSampler"]["image_home"]
        label_file = config["StyleSampler"]["label_file"]
        self.dataset_with_label = config["StyleSampler"]["with_label"]
        self.height = config["Global"]["image_height"]
        self.index = 0
        with open(label_file, "r") as f:
            label_raw = f.read()
            self.path_label_list = label_raw.split("\n")[:-1]
        assert len(self.path_label_list) > 0
        random.shuffle(self.path_label_list)

    def sample(self):
        if self.index >= len(self.path_label_list):
            random.shuffle(self.path_label_list)
            self.index = 0
        if self.dataset_with_label:
            path_label = self.path_label_list[self.index]
            rel_image_path, label = path_label.split('\t')
        else:
            rel_image_path = self.path_label_list[self.index]
            label = None
        img_path = "{}/{}".format(self.image_home, rel_image_path)
        image = cv2.imread(img_path)
        origin_height = image.shape[0]
        ratio = self.height / origin_height
        width = int(image.shape[1] * ratio)
        height = int(image.shape[0] * ratio)
        image = cv2.resize(image, (width, height))

        self.index += 1
        if label:
            return {"image": image, "label": label}
        else:
            return {"image": image}


def duplicate_image(image, width):
    image_width = image.shape[1]
    dup_num = width // image_width + 1
    image = np.tile(image, reps=[1, dup_num, 1])
    cropped_image = image[:, :width, :]
    return cropped_image
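A quick check of `duplicate_image` (a sketch, assuming the function above is in scope): the image is tiled just past the target width, then cropped back:

```python
import numpy as np

img = np.ones((32, 50, 3), dtype=np.uint8)
wide = duplicate_image(img, width=120)  # 50 -> tiled to 150 -> cropped to 120
assert wide.shape == (32, 120, 3)
```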
@ -0,0 +1,71 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from utils.config import ArgsParser, load_config, override_config
from utils.logging import get_logger
from engine import style_samplers, corpus_generators, text_drawers, predictors, writers


class ImageSynthesiser(object):
    def __init__(self):
        self.FLAGS = ArgsParser().parse_args()
        self.config = load_config(self.FLAGS.config)
        self.config = override_config(self.config, options=self.FLAGS.override)
        self.output_dir = self.config["Global"]["output_dir"]
        if not os.path.exists(self.output_dir):
            os.mkdir(self.output_dir)
        self.logger = get_logger(
            log_file='{}/predict.log'.format(self.output_dir))

        self.text_drawer = text_drawers.StdTextDrawer(self.config)

        predictor_method = self.config["Predictor"]["method"]
        assert predictor_method is not None
        self.predictor = getattr(predictors, predictor_method)(self.config)

    def synth_image(self, corpus, style_input, language="en"):
        corpus, text_input = self.text_drawer.draw_text(corpus, language)
        synth_result = self.predictor.predict(style_input, text_input)
        return synth_result


class DatasetSynthesiser(ImageSynthesiser):
    def __init__(self):
        super(DatasetSynthesiser, self).__init__()
        self.tag = self.FLAGS.tag
        self.output_num = self.config["Global"]["output_num"]
        corpus_generator_method = self.config["CorpusGenerator"]["method"]
        self.corpus_generator = getattr(corpus_generators,
                                        corpus_generator_method)(self.config)

        style_sampler_method = self.config["StyleSampler"]["method"]
        assert style_sampler_method is not None
        self.style_sampler = style_samplers.DatasetSampler(self.config)
        self.writer = writers.SimpleWriter(self.config, self.tag)

    def synth_dataset(self):
        for i in range(self.output_num):
            style_data = self.style_sampler.sample()
            style_input = style_data["image"]
            corpus_language, text_input_label = self.corpus_generator.generate()
            text_input_label, text_input = self.text_drawer.draw_text(
                text_input_label, corpus_language)

            synth_result = self.predictor.predict(style_input, text_input)
            fake_fusion = synth_result["fake_fusion"]
            self.writer.save_image(fake_fusion, text_input_label)
        self.writer.save_label()
        self.writer.merge_label()
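A hedged usage sketch mirroring `tools/synth_image.py`: construct the synthesiser (it parses its own command line, so the script must be launched with `-c configs/config.yml`) and synthesise one sample:

```python
import cv2

synthesiser = ImageSynthesiser()  # expects -c configs/config.yml on the CLI
style_input = cv2.imread("examples/style_images/1.jpg")
result = synthesiser.synth_image("PaddleOCR", style_input, language="en")
cv2.imwrite("fake_fusion.jpg", result["fake_fusion"])
```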
@ -0,0 +1,57 @@
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from utils.logging import get_logger


class StdTextDrawer(object):
    def __init__(self, config):
        self.logger = get_logger()
        self.max_width = config["Global"]["image_width"]
        self.char_list = " 0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
        self.height = config["Global"]["image_height"]
        self.font_dict = {}
        self.load_fonts(config["TextDrawer"]["fonts"])
        self.support_languages = list(self.font_dict)

    def load_fonts(self, fonts_config):
        for language in fonts_config:
            font_path = fonts_config[language]
            font_height = self.get_valid_height(font_path)
            font = ImageFont.truetype(font_path, font_height)
            self.font_dict[language] = font

    def get_valid_height(self, font_path):
        font = ImageFont.truetype(font_path, self.height - 4)
        _, font_height = font.getsize(self.char_list)
        if font_height <= self.height - 4:
            return self.height - 4
        else:
            return int((self.height - 4)**2 / font_height)

    def draw_text(self, corpus, language="en", crop=True):
        if language not in self.support_languages:
            self.logger.warning(
                "language {} not supported, use en instead.".format(language))
            language = "en"
        if crop:
            width = min(self.max_width, len(corpus) * self.height) + 4
        else:
            width = len(corpus) * self.height + 4
        bg = Image.new("RGB", (width, self.height), color=(127, 127, 127))
        draw = ImageDraw.Draw(bg)

        char_x = 2
        font = self.font_dict[language]
        for i, char_i in enumerate(corpus):
            char_size = font.getsize(char_i)[0]
            draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
            char_x += char_size
            if char_x >= width:
                corpus = corpus[0:i + 1]
                self.logger.warning("corpus length exceed limit: {}".format(
                    corpus))
                break

        text_input = np.array(bg).astype(np.uint8)
        text_input = text_input[:, 0:char_x, :]
        return corpus, text_input
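A minimal sketch of driving `StdTextDrawer` directly; the font path is a placeholder — any TTF file covering the corpus characters should work:

```python
config = {
    "Global": {"image_height": 32, "image_width": 320},
    "TextDrawer": {"fonts": {"en": "fonts/en_standard.ttf"}},  # hypothetical path
}
drawer = StdTextDrawer(config)
corpus, text_input = drawer.draw_text("PaddleOCR", language="en")
print(corpus, text_input.shape)  # ("PaddleOCR", (32, W, 3)); W tracks char widths
```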
@ -0,0 +1,71 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import glob

from utils.logging import get_logger


class SimpleWriter(object):
    def __init__(self, config, tag):
        self.logger = get_logger()
        self.output_dir = config["Global"]["output_dir"]
        self.counter = 0
        self.label_dict = {}
        self.tag = tag
        self.label_file_index = 0

    def save_image(self, image, text_input_label):
        image_home = os.path.join(self.output_dir, "images", self.tag)
        if not os.path.exists(image_home):
            os.makedirs(image_home)

        image_path = os.path.join(image_home, "{}.png".format(self.counter))
        # todo support continue synth
        cv2.imwrite(image_path, image)
        self.logger.info("generate image: {}".format(image_path))

        image_name = os.path.join(self.tag, "{}.png".format(self.counter))
        self.label_dict[image_name] = text_input_label

        self.counter += 1
        if not self.counter % 100:
            self.save_label()

    def save_label(self):
        label_raw = ""
        label_home = os.path.join(self.output_dir, "label")
        if not os.path.exists(label_home):
            os.mkdir(label_home)
        for image_path in self.label_dict:
            label = self.label_dict[image_path]
            label_raw += "{}\t{}\n".format(image_path, label)
        label_file_path = os.path.join(label_home,
                                       "{}_label.txt".format(self.tag))
        with open(label_file_path, "w") as f:
            f.write(label_raw)
        self.label_file_index += 1

    def merge_label(self):
        label_raw = ""
        label_file_regex = os.path.join(self.output_dir, "label",
                                        "*_label.txt")
        label_file_list = glob.glob(label_file_regex)
        for label_file_i in label_file_list:
            with open(label_file_i, "r") as f:
                label_raw += f.read()
        label_file_path = os.path.join(self.output_dir, "label.txt")
        with open(label_file_path, "w") as f:
            f.write(label_raw)
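To make the output layout concrete, a hedged sketch of the writer's contract: images go under `<output_dir>/images/<tag>/`, per-tag labels under `<output_dir>/label/<tag>_label.txt` as tab-separated `name\tlabel` lines, and `merge_label` concatenates them:

```python
import numpy as np

writer = SimpleWriter({"Global": {"output_dir": "./output"}}, tag="0")
writer.save_image(np.zeros((32, 100, 3), dtype="uint8"), "PaddleOCR")
writer.save_label()   # ./output/label/0_label.txt contains "0/0.png\tPaddleOCR"
writer.merge_label()  # concatenates ./output/label/*_label.txt into ./output/label.txt
```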
@ -0,0 +1,2 @@
PaddleOCR
飞桨文字识别
@ -0,0 +1,2 @@
style_images/1.jpg NEATNESS
style_images/2.jpg 锁店君和宾馆
@ -0,0 +1,23 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from engine.synthesisers import DatasetSynthesiser


def synth_dataset():
    dataset_synthesiser = DatasetSynthesiser()
    dataset_synthesiser.synth_dataset()


if __name__ == '__main__':
    synth_dataset()
@ -0,0 +1,82 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import sys
import glob

# The path setup must precede the package imports below, otherwise the
# utils/engine modules cannot be resolved when the script is run directly.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))

from utils.config import ArgsParser
from engine.synthesisers import ImageSynthesiser


def synth_image():
    args = ArgsParser().parse_args()
    image_synthesiser = ImageSynthesiser()
    style_image_path = args.style_image
    img = cv2.imread(style_image_path)
    text_corpus = args.text_corpus
    language = args.language

    synth_result = image_synthesiser.synth_image(text_corpus, img, language)
    fake_fusion = synth_result["fake_fusion"]
    fake_text = synth_result["fake_text"]
    fake_bg = synth_result["fake_bg"]
    cv2.imwrite("fake_fusion.jpg", fake_fusion)
    cv2.imwrite("fake_text.jpg", fake_text)
    cv2.imwrite("fake_bg.jpg", fake_bg)


def batch_synth_images():
    image_synthesiser = ImageSynthesiser()

    corpus_file = "../StyleTextRec_data/test_20201208/test_text_list.txt"
    style_data_dir = "../StyleTextRec_data/test_20201208/style_images/"
    save_path = "./output_data/"
    corpus_list = []
    with open(corpus_file, "rb") as fin:
        lines = fin.readlines()
        for line in lines:
            substr = line.decode("utf-8").strip("\n").split("\t")
            corpus_list.append(substr)
    style_img_list = glob.glob("{}/*.jpg".format(style_data_dir))
    corpus_num = len(corpus_list)
    style_img_num = len(style_img_list)
    for cno in range(corpus_num):
        for sno in range(style_img_num):
            corpus, lang = corpus_list[cno]
            style_img_path = style_img_list[sno]
            img = cv2.imread(style_img_path)
            synth_result = image_synthesiser.synth_image(corpus, img, lang)
            fake_fusion = synth_result["fake_fusion"]
            fake_text = synth_result["fake_text"]
            fake_bg = synth_result["fake_bg"]
            for tp in range(2):
                if tp == 0:
                    prefix = "%s/c%d_s%d_" % (save_path, cno, sno)
                else:
                    prefix = "%s/s%d_c%d_" % (save_path, sno, cno)
                cv2.imwrite("%s_fake_fusion.jpg" % prefix, fake_fusion)
                cv2.imwrite("%s_fake_text.jpg" % prefix, fake_text)
                cv2.imwrite("%s_fake_bg.jpg" % prefix, fake_bg)
                cv2.imwrite("%s_input_style.jpg" % prefix, img)
                print(cno, corpus_num, sno, style_img_num)


if __name__ == '__main__':
    # batch_synth_images()
    synth_image()
@ -0,0 +1,224 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
import os
from argparse import ArgumentParser, RawDescriptionHelpFormatter

from utils.logging import get_logger


def override(dl, ks, v):
    """
    Recursively replace a value in a nested dict or list.

    Args:
        dl(dict or list): dict or list to be replaced
        ks(list): list of keys
        v(str): value to be replaced
    """

    def str2num(v):
        try:
            return eval(v)
        except Exception:
            return v

    assert isinstance(dl, (list, dict)), "{} should be a list or a dict".format(dl)
    assert len(ks) > 0, 'length of keys should be larger than 0'
    if isinstance(dl, list):
        k = str2num(ks[0])
        if len(ks) == 1:
            assert k < len(dl), 'index({}) out of range({})'.format(k, dl)
            dl[k] = str2num(v)
        else:
            override(dl[k], ks[1:], v)
    else:
        if len(ks) == 1:
            # assert ks[0] in dl, ('{} does not exist in {}'.format(ks[0], dl))
            if not ks[0] in dl:
                get_logger().warning('A new field ({}) detected!'.format(ks[0]))
            dl[ks[0]] = str2num(v)
        else:
            assert ks[0] in dl, (
                '({}) doesn\'t exist in {}, a new dict field is invalid'.
                format(ks[0], dl))
            override(dl[ks[0]], ks[1:], v)


def override_config(config, options=None):
    """
    Recursively override the config.

    Args:
        config(dict): dict to be replaced
        options(list): list of pairs(key0.key1.idx.key2=value)
            such as: [
                'topk=2',
                'VALID.transforms.1.ResizeImage.resize_short=300'
            ]

    Returns:
        config(dict): replaced config
    """
    if options is not None:
        for opt in options:
            assert isinstance(opt, str), (
                "option({}) should be a str".format(opt))
            assert "=" in opt, (
                "option({}) should contain a ="
                "to distinguish between key and value".format(opt))
            pair = opt.split('=')
            assert len(pair) == 2, "there can only be one = in an option"
            key, value = pair
            keys = key.split('.')
            override(config, keys, value)

    return config


class ArgsParser(ArgumentParser):
    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-t", "--tag", default="0", help="tag for marking worker")
        self.add_argument(
            '-o',
            '--override',
            action='append',
            default=[],
            help='config options to be overridden')
        self.add_argument(
            "--style_image",
            default="examples/style_images/1.jpg",
            help="path of the style input image")
        self.add_argument(
            "--text_corpus", default="PaddleOCR",
            help="text input to synthesise")
        self.add_argument(
            "--language", default="en", help="language of the text corpus")

    def parse_args(self, argv=None):
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        return args


def load_config(file_path):
    """
    Load config from yml/yaml file.
    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: config
    """
    ext = os.path.splitext(file_path)[1]
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    with open(file_path, 'rb') as f:
        config = yaml.load(f, Loader=yaml.Loader)

    return config


def gen_config():
    base_config = {
        "Global": {
            "algorithm": "SRNet",
            "use_gpu": True,
            "start_epoch": 1,
            "stage1_epoch_num": 100,
            "stage2_epoch_num": 100,
            "log_smooth_window": 20,
            "print_batch_step": 2,
            "save_model_dir": "./output/SRNet",
            "use_visualdl": False,
            "save_epoch_step": 10,
            "vgg_pretrain": "./pretrained/VGG19_pretrained",
            "vgg_load_static_pretrain": True
        },
        "Architecture": {
            "model_type": "data_aug",
            "algorithm": "SRNet",
            "net_g": {
                "name": "srnet_net_g",
                "encode_dim": 64,
                "norm": "batch",
                "use_dropout": False,
                "init_type": "xavier",
                "init_gain": 0.02,
                "use_dilation": 1
            },
            # input_nc, ndf, netD,
            # n_layers_D=3, norm='instance', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_id='cuda:0'
            "bg_discriminator": {
                "name": "srnet_bg_discriminator",
                "input_nc": 6,
                "ndf": 64,
                "netD": "basic",
                "norm": "none",
                "init_type": "xavier",
            },
            "fusion_discriminator": {
                "name": "srnet_fusion_discriminator",
                "input_nc": 6,
                "ndf": 64,
                "netD": "basic",
                "norm": "none",
                "init_type": "xavier",
            }
        },
        "Loss": {
            "lamb": 10,
            "perceptual_lamb": 1,
            "muvar_lamb": 50,
            "style_lamb": 500
        },
        "Optimizer": {
            "name": "Adam",
            "learning_rate": {
                "name": "lambda",
                "lr": 0.0002,
                "lr_decay_iters": 50
            },
            "beta1": 0.5,
            "beta2": 0.999,
        },
        "Train": {
            "batch_size_per_card": 8,
            "num_workers_per_card": 4,
            "dataset": {
                "delimiter": "\t",
                "data_dir": "/",
                "label_file": "tmp/label.txt",
                "transforms": [{
                    "DecodeImage": {
                        "to_rgb": True,
                        "to_np": False,
                        "channel_first": False
                    }
                }, {
                    "NormalizeImage": {
                        "scale": 1. / 255.,
                        "mean": [0.485, 0.456, 0.406],
                        "std": [0.229, 0.224, 0.225],
                        "order": None
                    }
                }, {
                    "ToCHWImage": None
                }]
            }
        }
    }
    with open("config.yml", "w") as f:
        yaml.dump(base_config, f)


if __name__ == '__main__':
    gen_config()
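A concrete example of the `-o/--override` syntax these helpers parse: dotted keys walk nested dicts and integer segments index into lists (note that `str2num` uses `eval`, so booleans must be spelled `True`/`False`):

```python
config = {"Global": {"use_gpu": True},
          "Predictor": {"mean": [0.5, 0.5, 0.5]}}
config = override_config(config, options=[
    "Global.use_gpu=False",    # nested dict key
    "Predictor.mean.0=0.485",  # list index via an integer segment
])
assert config["Global"]["use_gpu"] is False
assert config["Predictor"]["mean"][0] == 0.485
```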
@ -0,0 +1,27 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle

__all__ = ['load_dygraph_pretrain']


def load_dygraph_pretrain(model, logger, path=None, load_static_weights=False):
    if not os.path.exists(path + '.pdparams'):
        raise ValueError("Model pretrain path {} does not "
                         "exist.".format(path))
    param_state_dict = paddle.load(path + '.pdparams')
    model.set_state_dict(param_state_dict)
    logger.info("load pretrained model from {}".format(path))
    return
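A runnable sketch (with a stand-in layer) showing that `path` is the checkpoint prefix without the `.pdparams` suffix:

```python
import paddle
from utils.logging import get_logger

model = paddle.nn.Linear(4, 4)  # stand-in for e.g. a generator network
paddle.save(model.state_dict(), "demo.pdparams")
load_dygraph_pretrain(model, get_logger(), path="demo")  # note: no suffix
```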
@ -0,0 +1,65 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
import functools
import paddle.distributed as dist

logger_initialized = {}


@functools.lru_cache()
def get_logger(name='srnet', log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added. If `log_file` is specified a FileHandler will also be added.
    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger
    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    formatter = logging.Formatter(
        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
        datefmt="%Y/%m/%d %H:%M:%S")

    stream_handler = logging.StreamHandler(stream=sys.stdout)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    if log_file is not None and dist.get_rank() == 0:
        log_file_folder = os.path.split(log_file)[0]
        os.makedirs(log_file_folder, exist_ok=True)
        file_handler = logging.FileHandler(log_file, 'a')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    if dist.get_rank() == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)
    logger_initialized[name] = True
    return logger
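The `lru_cache` plus the `logger_initialized` bookkeeping make repeated calls cheap and idempotent — handlers are attached exactly once per logger name, and names extending an initialized name reuse its handlers:

```python
a = get_logger()  # configures the 'srnet' logger on first use
b = get_logger()  # cached: same object, no duplicate handlers
assert a is b
b.info("logged once, not twice")
```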
@ -0,0 +1,45 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle


def compute_mean_covariance(img):
    batch_size = img.shape[0]
    channel_num = img.shape[1]
    height = img.shape[2]
    width = img.shape[3]
    num_pixels = height * width

    # batch_size * channel_num * 1 * 1
    mu = img.mean(2, keepdim=True).mean(3, keepdim=True)

    # batch_size * channel_num * num_pixels
    img_hat = img - mu.expand_as(img)
    img_hat = img_hat.reshape([batch_size, channel_num, num_pixels])
    # batch_size * num_pixels * channel_num
    img_hat_transpose = img_hat.transpose([0, 2, 1])
    # batch_size * channel_num * channel_num
    covariance = paddle.bmm(img_hat, img_hat_transpose)
    covariance = covariance / num_pixels

    return mu, covariance


def dice_coefficient(y_true_cls, y_pred_cls, training_mask):
    eps = 1e-5
    intersection = paddle.sum(y_true_cls * y_pred_cls * training_mask)
    union = paddle.sum(y_true_cls * training_mask) + paddle.sum(
        y_pred_cls * training_mask) + eps
    loss = 1. - (2 * intersection / union)
    return loss
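A quick shape check for `compute_mean_covariance` on a random batch (a sketch; requires a working paddle install):

```python
import paddle

img = paddle.rand([2, 3, 8, 8])  # N, C, H, W
mu, cov = compute_mean_covariance(img)
assert mu.shape == [2, 3, 1, 1]  # per-channel means
assert cov.shape == [2, 3, 3]    # per-sample channel covariance
```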
@ -0,0 +1,67 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import errno
import paddle


def get_check_global_params(mode):
    check_params = [
        'use_gpu', 'max_text_length', 'image_shape',
        'character_type', 'loss_type'
    ]
    if mode == "train_eval":
        check_params = check_params + [
            'train_batch_size_per_card', 'test_batch_size_per_card'
        ]
    elif mode == "test":
        check_params = check_params + ['test_batch_size_per_card']
    return check_params


def check_gpu(use_gpu):
    """
    Log error and exit when use_gpu is set to true with a CPU-only
    build of PaddlePaddle.
    """
    err = "Config use_gpu cannot be set as true while you are " \
          "using paddlepaddle cpu version ! \nPlease try: \n" \
          "\t1. Install paddlepaddle-gpu to run model on GPU \n" \
          "\t2. Set use_gpu as false in config file to run " \
          "model on CPU"
    if use_gpu:
        try:
            if not paddle.is_compiled_with_cuda():
                print(err)
                sys.exit(1)
        except Exception:
            print("Failed to check GPU state.")
            sys.exit(1)


def _mkdir_if_not_exist(path, logger):
    """
    mkdir if not exists; ignore the exception when multiple processes
    mkdir together.
    """
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    '{} already exists, probably created by another process'.
                    format(path))
            else:
                raise OSError('Failed to mkdir {}'.format(path))
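Note that `check_gpu` exits the process rather than raising, so it should run once up front, before any model is built:

```python
check_gpu(use_gpu=False)  # always a no-op
check_gpu(use_gpu=True)   # prints the hint and exits on a CPU-only build
```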
@ -36,12 +36,13 @@ Architecture:
   algorithm: CRNN
   Transform:
   Backbone:
-    name: ResNet
-    layers: 34
+    name: MobileNetV3
+    scale: 0.5
+    model_name: large
   Neck:
     name: SequenceEncoder
     encoder_type: rnn
-    hidden_size: 256
+    hidden_size: 96
   Head:
     name: CTCHead
     fc_decay: 0
@ -9,7 +9,7 @@
 ### 1. Text detection algorithms
 
 List of text detection algorithms open-sourced by PaddleOCR:
-- [x] DB([paper](https://arxiv.org/abs/1911.08947)) (recommended by ppocr)
+- [x] DB([paper]( https://arxiv.org/abs/1911.08947) ) (recommended by ppocr)
 - [x] EAST([paper](https://arxiv.org/abs/1704.03155))
 - [x] SAST([paper](https://arxiv.org/abs/1908.05498))
@ -38,9 +38,9 @@ PaddleOCR文本检测算法的训练和使用请参考文档教程中[模型训
 ### 2. Text recognition algorithms
 
 List of dynamic-graph text recognition algorithms open-sourced by PaddleOCR:
-- [x] CRNN([paper](https://arxiv.org/abs/1507.05717)) (recommended by ppocr)
+- [x] CRNN([paper](https://arxiv.org/abs/1507.05717) ) (recommended by ppocr)
 - [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
-- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
+- [ ] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
 - [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1)) coming soon
 - [ ] SRN([paper](https://arxiv.org/abs/2003.12294)) coming soon
@ -62,9 +62,9 @@ PaddleOCR提供了训练脚本、评估脚本和预测脚本。
 *If you installed the CPU version, please set the `use_gpu` field in the configuration file to false*
 
 ```
-# GPU training: single-card and multi-card training are supported; specify the card numbers with selected_gpus
+# GPU training: single-card and multi-card training are supported; specify the card numbers with '--gpus'. If your paddle version is lower than 2.0rc1, please use '--selected_gpus'
 # Start training. The command below is already written into train.sh; just modify the configuration file path in that file
-python3 -m paddle.distributed.launch --selected_gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
+python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
 ```
 
 - Data augmentation
@ -74,7 +74,7 @@ PaddleOCR提供了多种数据增强方式,如果您希望在训练时加入
 The default perturbation methods are: color space conversion (cvtColor), blur, jitter, Gauss noise, random crop, perspective, color reverse, and random data augmentation (RandAugment).
 
 During training, every perturbation except RandAugment is selected with a 50% probability. For the concrete implementation, please refer to:
-[rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py)
+[rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py)
 [randaugment.py](../../ppocr/data/imaug/randaugment.py)
 
 *Due to OpenCV compatibility issues, the perturbation operations currently only support Linux*
@ -107,17 +107,13 @@ PaddleOCR计算三个OCR检测相关的指标,分别是:Precision、Recall
 
 Run the following code to compute the evaluation metrics from the test-set detection result file specified by `save_res_path` in the configuration file `det_db_mv3.yml`.
 
-When evaluating, set the post-processing parameters `box_thresh=0.6` and `unclip_ratio=1.5`; when training on other datasets or models, these two parameters can be tuned for better results
-```shell
-python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
-```
+When evaluating, set the post-processing parameters `box_thresh=0.5` and `unclip_ratio=1.5`; when training on other datasets or models, these two parameters can be tuned for better results
 During training, the model parameters are saved under `Global.save_model_dir` by default. To compute metrics, set `Global.checkpoints` to the saved parameter file.
 
 For example:
 ```shell
-python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
+python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.5 PostProcess.unclip_ratio=1.5
 ```
 
 
 * Note: `box_thresh` and `unclip_ratio` are parameters required by DB post-processing and do not need to be set when evaluating the EAST model
 
 ## Testing detection results
@ -22,9 +22,8 @@ inference 模型(`paddle.jit.save`保存的模型)
 - [3. Text recognition model inference](#文本识别模型推理)
     - [1. Lightweight Chinese recognition model inference](#超轻量中文识别模型推理)
     - [2. CTC-based recognition model inference](#基于CTC损失的识别模型推理)
-    - [3. Attention-based recognition model inference](#基于Attention损失的识别模型推理)
-    - [4. Inference with a custom text recognition dictionary](#自定义文本识别字典的推理)
-    - [5. Multilingual model inference](#多语言模型的推理)
+    - [3. Inference with a custom text recognition dictionary](#自定义文本识别字典的推理)
+    - [4. Multilingual model inference](#多语言模型的推理)
 
 - [4. Angle classification model inference](#方向识别模型推理)
     - [1. Angle classification model inference](#方向分类模型推理)
@ -129,24 +128,32 @@ python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.pretrained_mo
 For lightweight Chinese detection model inference, you can execute the following commands:
 
 ```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/"
+# Download the lightweight Chinese detection model:
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
+tar xf ch_ppocr_mobile_v2.0_det_infer.tar
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./ch_ppocr_mobile_v2.0_det_infer/"
 ```
 
 The visualized text detection results are saved to the `./inference_results` folder by default, and result file names are prefixed with 'det_res'. An example result:
 
-![](../imgs_results/det_res_2.jpg)
+![](../imgs_results/det_res_22.jpg)
 
-The image size is limited by the parameters `limit_type` and `det_limit_side_len`: `limit_type=max` limits the long side to <`det_limit_side_len`, while `limit_type=min` limits the short side to >`det_limit_side_len`.
-When the image violates the limit (long side >`det_limit_side_len` with `limit_type=max`, or short side <`det_limit_side_len` with `limit_type=min`), it is scaled proportionally.
-The default is `limit_type='max', det_max_side_len=960`. If the input image has a high resolution and you want to predict at a larger resolution, you can execute the following command:
+The parameters `limit_type` and `det_limit_side_len` limit the size of the input image.
+The optional values of `limit_type` are [`max`, `min`], and
+`det_limit_side_len` is a positive integer, generally set to a multiple of 32, such as 960.
+
+The parameters default to `limit_type='max', det_limit_side_len=960`, meaning the longest side of the network input image cannot exceed 960;
+if it does, the image is resized with the aspect ratio kept so that the longest side equals `det_limit_side_len`.
+Setting `limit_type='min', det_limit_side_len=960` instead limits the shortest side of the image to 960.
+
+If the input image has a high resolution and you want to predict at a larger resolution, set det_limit_side_len to the desired value, such as 1216:
 ```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1200
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
 ```
 
 To predict on CPU, execute the following command:
 ```
 python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
 ```
 
 <a name="DB文本检测模型推理"></a>
|
@ -268,16 +275,6 @@ CRNN 文本识别模型推理,可以执行如下命令:
|
|||
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/rec_crnn/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
|
||||
```
|
||||
|
||||
<a name="基于Attention损失的识别模型推理"></a>
|
||||
### 3. 基于Attention损失的识别模型推理
|
||||
|
||||
基于Attention损失的识别模型与ctc不同,需要额外设置识别算法参数 --rec_algorithm="RARE"
|
||||
RARE 文本识别模型推理,可以执行如下命令:
|
||||
```
|
||||
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/rare/" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_algorithm="RARE"
|
||||
|
||||
```
|
||||
|
||||
![](../imgs_words_en/word_336.png)
|
||||
|
||||
执行命令后,上面图像的识别结果如下:
|
||||
|
@ -297,7 +294,7 @@ self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
 dict_character = list(self.character_str)
 ```
 
-### 4. Inference with a custom text recognition dictionary
+### 3. Inference with a custom text recognition dictionary
 If the text dictionary was modified during training, you need to specify the dictionary path with `--rec_char_dict_path` when predicting with the inference model, and set `rec_char_type=ch`
 
 ```
@ -305,7 +302,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
 ```
 
 <a name="多语言模型的推理"></a>
-### 5. Multilingual model inference
+### 4. Multilingual model inference
 If you need to predict with a model for another language, you need to specify the dictionary path with `--rec_char_dict_path` when predicting with the inference model. To get correct visualization results,
 you also need to specify the visualization font path with `--vis_font_path`; fonts for small languages are provided by default under the `doc/` path, e.g. for Korean recognition:
@ -167,7 +167,7 @@ tar -xf rec_mv3_none_bilstm_ctc_v2.0_train.tar && rm -rf rec_mv3_none_bilstm_ctc
 
 ```
 # GPU training: single-card and multi-card training are supported; specify the card numbers with --gpus
-# Train on the icdar15 English data and save the training log as train_rec.log
+# Train on the icdar15 English data; the training log is automatically saved as train.log under "{save_model_dir}"
 python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_icdar15_train.yml
 ```
 <a name="数据增强"></a>
@ -200,11 +200,8 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
 | rec_icdar15_train.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
 | rec_mv3_none_bilstm_ctc.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
 | rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
-| rec_mv3_tps_bilstm_ctc.yml | STARNet | Mobilenet_v3 large 0.5 | tps | BiLSTM | ctc |
-| rec_mv3_tps_bilstm_attn.yml | RARE | Mobilenet_v3 large 0.5 | tps | BiLSTM | attention |
 | rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
 | rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
-| rec_r34_vd_tps_bilstm_ctc.yml | STARNet | Resnet34_vd | tps | BiLSTM | ctc |
 
 For training Chinese data, [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml) is recommended. If you want to try other algorithms on the Chinese dataset, please refer to the following instructions to modify the configuration file:
@ -356,8 +353,7 @@ python3 tools/infer_rec.py -c configs/rec/rec_icdar15_train.yml -o Global.checkp
 
 ```
 infer_img: doc/imgs_words/en/word_1.png
-index: [19 24 18 23 29]
-word : joint
+result: ('joint', 0.9998967)
 ```
 
 The configuration file used for prediction must match the one used for training. For example, if you finished training the Chinese model with `python3 tools/train.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml`,
@ -376,6 +372,5 @@ python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v
 
 ```
 infer_img: doc/imgs_words/ch/word_1.jpg
-index: [2092 177 312 2503]
-word : 韩国小馆
+result: ('韩国小馆', 0.997218)
 ```
@ -13,7 +13,7 @@ This tutorial lists the text detection algorithms and text recognition algorithm
 PaddleOCR open source text detection algorithms list:
 - [x] EAST([paper](https://arxiv.org/abs/1704.03155))
 - [x] DB([paper](https://arxiv.org/abs/1911.08947))
-- [x] SAST([paper](https://arxiv.org/abs/1908.05498))(Baidu Self-Research)
+- [x] SAST([paper](https://arxiv.org/abs/1908.05498) )(Baidu Self-Research)
 
 On the ICDAR2015 dataset, the text detection result is as follows:
@ -41,9 +41,9 @@ For the training guide and use of PaddleOCR text detection algorithms, please re
 PaddleOCR open-source text recognition algorithms list:
 - [x] CRNN([paper](https://arxiv.org/abs/1507.05717))
 - [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
-- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
+- [ ] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
 - [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1)) coming soon
-- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))(Baidu Self-Research) coming soon
+- [ ] SRN([paper](https://arxiv.org/abs/2003.12294) )(Baidu Self-Research) coming soon
 
 Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follows:
@ -65,9 +65,9 @@ Start training:
 ```
 # Set PYTHONPATH path
 export PYTHONPATH=$PYTHONPATH:.
-# GPU training Support single card and multi-card training, specify the card number through selected_gpus
+# GPU training Support single card and multi-card training, specify the card number through --gpus. If your paddle version is less than 2.0rc1, please use '--selected_gpus'
 # Start training, the following command has been written into the train.sh file, just modify the configuration file path in the file
-python3 -m paddle.distributed.launch --selected_gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
+python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
 ```
 
 - Data Augmentation
@ -77,7 +77,7 @@ PaddleOCR provides a variety of data augmentation methods. If you want to add di
 The default perturbation methods are: cvtColor, blur, jitter, Gauss noise, random crop, perspective, color reverse, RandAugment.
 
 Except for RandAugment, each disturbance method is selected with a 50% probability during the training process. For the specific code implementation, please refer to:
-[rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py)
+[rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py)
 [randaugment.py](../../ppocr/data/imaug/randaugment.py)
@ -101,15 +101,11 @@ Run the following code to calculate the evaluation indicators. The result will b
 
 When evaluating, set post-processing parameters `box_thresh=0.6`, `unclip_ratio=1.5`. If you use different datasets or models for training, these two parameters should be adjusted for better results.
 
-```shell
-python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
-```
 The model parameters during training are saved in the `Global.save_model_dir` directory by default. When evaluating indicators, you need to set `Global.checkpoints` to point to the saved parameter file.
 
 Such as:
 ```shell
 python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
 ```
 
 * Note: `box_thresh` and `unclip_ratio` are parameters required for DB post-processing, and do not need to be set when evaluating the EAST model.
@ -25,9 +25,8 @@ Next, we first introduce how to convert a trained model into an inference model,
 - [TEXT RECOGNITION MODEL INFERENCE](#RECOGNITION_MODEL_INFERENCE)
     - [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_RECOGNITION)
     - [2. CTC-BASED TEXT RECOGNITION MODEL INFERENCE](#CTC-BASED_RECOGNITION)
-    - [3. ATTENTION-BASED TEXT RECOGNITION MODEL INFERENCE](#ATTENTION-BASED_RECOGNITION)
-    - [4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
-    - [5. MULTILINGUAL MODEL INFERENCE](#MULTILINGUAL_MODEL_INFERENCE)
+    - [3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
+    - [4. MULTILINGUAL MODEL INFERENCE](#MULTILINGUAL_MODEL_INFERENCE)
 
 - [ANGLE CLASSIFICATION MODEL INFERENCE](#ANGLE_CLASS_MODEL_INFERENCE)
     - [1. ANGLE CLASSIFICATION MODEL INFERENCE](#ANGLE_CLASS_MODEL_INFERENCE)
@ -135,24 +134,33 @@ Because EAST and DB algorithms are very different, when inference, it is necessa
 For lightweight Chinese detection model inference, you can execute the following commands:
 
 ```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/"
+# download DB text detection inference model
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
+tar xf ch_ppocr_mobile_v2.0_det_infer.tar
+# predict
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/"
 ```
 
 The visual text detection results are saved to the ./inference_results folder by default, and the names of the result files are prefixed with 'det_res'. Examples of results are as follows:
 
-![](../imgs_results/det_res_2.jpg)
+![](../imgs_results/det_res_22.jpg)
 
-The size of the image is limited by the parameters `limit_type` and `det_limit_side_len`: `limit_type=max` limits the length of the long side to <`det_limit_side_len`, and `limit_type=min` limits the length of the short side to >`det_limit_side_len`.
-When the picture does not meet the restrictions (long side >`det_limit_side_len` for `limit_type=max`, or short side <`det_limit_side_len` for `min`), the image will be scaled proportionally.
-This parameter is set to `limit_type='max', det_max_side_len=960` by default. If the resolution of the input picture is relatively large and you want to use a larger resolution for prediction, you can execute the following command:
+You can use the parameters `limit_type` and `det_limit_side_len` to limit the size of the input image.
+The optional values of `limit_type` are [`max`, `min`], and
+`det_limit_side_len` is a positive integer, generally set to a multiple of 32, such as 960.
+
+The default setting of the parameters is `limit_type='max', det_limit_side_len=960`, which means that the longest side of the network input image cannot exceed 960;
+if it does, the image will be resized with the aspect ratio preserved to ensure that the longest side is `det_limit_side_len`.
+Setting `limit_type='min', det_limit_side_len=960` instead limits the shortest side of the image to 960.
+
+If the resolution of the input picture is relatively large and you want to use a larger resolution for prediction, you can set det_limit_side_len to the desired value, such as 1216:
 ```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1200
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
 ```
 
 If you want to use the CPU for prediction, execute the command as follows
 ```
-python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
 ```
 
 <a name="DB_DETECTION"></a>
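As a reading aid for the resize policy described in the hunk above, a hedged sketch of the scaling ratio it implies (`max` caps the longest side; `min` raises the shortest side):

```python
def det_resize_ratio(h, w, limit_type="max", det_limit_side_len=960):
    if limit_type == "max":
        side = max(h, w)
        return det_limit_side_len / side if side > det_limit_side_len else 1.0
    side = min(h, w)
    return det_limit_side_len / side if side < det_limit_side_len else 1.0

print(det_resize_ratio(1080, 1920))  # 0.5: the 1920 side shrinks to 960
```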
@ -275,15 +283,6 @@ For CRNN text recognition model inference, execute the following commands:
 python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
 ```
 
-<a name="ATTENTION-BASED_RECOGNITION"></a>
-### 3. ATTENTION-BASED TEXT RECOGNITION MODEL INFERENCE
-
-The recognition model based on Attention loss is different from ctc, and additional recognition algorithm parameters need to be set --rec_algorithm="RARE"
-After executing the command, the recognition result of the above image is as follows:
-```bash
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/rare/" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_algorithm="RARE"
-```
-
 ![](../imgs_words_en/word_336.png)
 
 After executing the command, the recognition result of the above image is as follows:
@ -303,7 +302,7 @@ dict_character = list(self.character_str)
 ```
 
 <a name="USING_CUSTOM_CHARACTERS"></a>
-### 4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY
+### 3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY
 If the text dictionary is modified during training, when using the inference model to predict, you need to specify the dictionary path used by `--rec_char_dict_path`, and set `rec_char_type=ch`
 
 ```
@ -311,7 +310,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
 ```
 
 <a name="MULTILINGUAL_MODEL_INFERENCE"></a>
-### 5. MULTILINGUAL MODEL INFERENCE
+### 4. MULTILINGUAL MODEL INFERENCE
 If you need to predict other language models, when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results,
 you need to specify the visual font path through `--vis_font_path`. There are fonts for small languages provided by default under the `doc/` path, such as Korean recognition:
@ -162,7 +162,7 @@ Start training:
 
 ```
 # GPU training Support single card and multi-card training, specify the card number through --gpus
-# Training icdar15 English data and saving the log as train_rec.log
+# Training icdar15 English data; the training log will be automatically saved as train.log under "{save_model_dir}"
 python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_icdar15_train.yml
 ```
 <a name="Data_Augmentation"></a>
@ -193,11 +193,8 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
 | rec_icdar15_train.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
 | rec_mv3_none_bilstm_ctc.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
 | rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
-| rec_mv3_tps_bilstm_ctc.yml | STARNet | Mobilenet_v3 large 0.5 | tps | BiLSTM | ctc |
-| rec_mv3_tps_bilstm_attn.yml | RARE | Mobilenet_v3 large 0.5 | tps | BiLSTM | attention |
 | rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc |
 | rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc |
-| rec_r34_vd_tps_bilstm_ctc.yml | STARNet | Resnet34_vd | tps | BiLSTM | ctc |
 
 For training Chinese data, it is recommended to use
 [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file:
@ -350,8 +347,7 @@ Get the prediction result of the input image:
 
 ```
 infer_img: doc/imgs_words/en/word_1.png
-index: [19 24 18 23 29]
-word : joint
+result: ('joint', 0.9998967)
 ```
 
 The configuration file used for prediction must be consistent with the training. For example, if you completed the training of the Chinese model with `python3 tools/train.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml`, you can use the following command to predict the Chinese model:
@ -369,6 +365,5 @@ Get the prediction result of the input image:
 
 ```
 infer_img: doc/imgs_words/ch/word_1.jpg
-index: [2092 177 312 2503]
-word : 韩国小馆
+result: ('韩国小馆', 0.997218)
 ```
@ -180,7 +180,6 @@ class GridGenerator(nn.Layer):
         P = self.build_P_paddle(I_r_size)
 
         inv_delta_C_tensor = self.build_inv_delta_C_paddle(C).astype('float32')
-        # inv_delta_C_tensor = paddle.zeros((23,23)).astype('float32')
         P_hat_tensor = self.build_P_hat_paddle(
             C, paddle.to_tensor(P)).astype('float32')