Merge branch 'dygraph' into dygraph
|
@ -1,8 +1,9 @@
|
|||
include LICENSE.txt
|
||||
include LICENSE
|
||||
include README.md
|
||||
|
||||
recursive-include ppocr/utils *.txt utility.py logging.py
|
||||
recursive-include ppocr/data/ *.py
|
||||
recursive-include ppocr/utils *.txt utility.py logging.py network.py
|
||||
recursive-include ppocr/data *.py
|
||||
recursive-include ppocr/postprocess *.py
|
||||
recursive-include tools/infer *.py
|
||||
recursive-include ppocr/utils/e2e_utils/ *.py
|
||||
recursive-include ppocr/utils/e2e_utils *.py
|
||||
recursive-include ppstructure *.py
|
|
@ -27,7 +27,12 @@ import json
|
|||
import cv2
|
||||
|
||||
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
sys.path.append("..")
|
||||
|
@ -267,6 +272,8 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
self.colorDialog = ColorDialog(parent=self)
|
||||
self.zoomWidgetValue = self.zoomWidget.value()
|
||||
|
||||
self.msgBox = QMessageBox()
|
||||
|
||||
########## thumbnail #########
|
||||
hlayout = QHBoxLayout()
|
||||
m = (0, 0, 0, 0)
|
||||
|
@ -360,6 +367,9 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
opendir = action(getStr('openDir'), self.openDirDialog,
|
||||
'Ctrl+u', 'open', getStr('openDir'))
|
||||
|
||||
open_dataset_dir = action(getStr('openDatasetDir'), self.openDatasetDirDialog,
|
||||
'Ctrl+p', 'open', getStr('openDatasetDir'), enabled=False)
|
||||
|
||||
save = action(getStr('save'), self.saveFile,
|
||||
'Ctrl+V', 'verify', getStr('saveDetail'), enabled=False)
|
||||
|
||||
|
@ -398,6 +408,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
help = action(getStr('tutorial'), self.showTutorialDialog, None, 'help', getStr('tutorialDetail'))
|
||||
showInfo = action(getStr('info'), self.showInfoDialog, None, 'help', getStr('info'))
|
||||
showSteps = action(getStr('steps'), self.showStepsDialog, None, 'help', getStr('steps'))
|
||||
showKeys = action(getStr('keys'), self.showKeysDialog, None, 'help', getStr('keys'))
|
||||
|
||||
zoom = QWidgetAction(self)
|
||||
zoom.setDefaultWidget(self.zoomWidget)
|
||||
|
@ -456,6 +467,12 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
undoLastPoint = action(getStr("undoLastPoint"), self.canvas.undoLastPoint,
|
||||
'Ctrl+Z', "undo", getStr("undoLastPoint"), enabled=False)
|
||||
|
||||
rotateLeft = action(getStr("rotateLeft"), partial(self.rotateImgAction,1),
|
||||
'Ctrl+Alt+L', "rotateLeft", getStr("rotateLeft"), enabled=False)
|
||||
|
||||
rotateRight = action(getStr("rotateRight"), partial(self.rotateImgAction,-1),
|
||||
'Ctrl+Alt+R', "rotateRight", getStr("rotateRight"), enabled=False)
|
||||
|
||||
undo = action(getStr("undo"), self.undoShapeEdit,
|
||||
'Ctrl+Z', "undo", getStr("undo"), enabled=False)
|
||||
|
||||
|
@ -519,13 +536,14 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
zoom=zoom, zoomIn=zoomIn, zoomOut=zoomOut, zoomOrg=zoomOrg,
|
||||
fitWindow=fitWindow, fitWidth=fitWidth,
|
||||
zoomActions=zoomActions, saveLabel=saveLabel,
|
||||
undo=undo, undoLastPoint=undoLastPoint,
|
||||
undo=undo, undoLastPoint=undoLastPoint,open_dataset_dir=open_dataset_dir,
|
||||
rotateLeft=rotateLeft,rotateRight=rotateRight,
|
||||
fileMenuActions=(
|
||||
opendir, saveLabel, resetAll, quit),
|
||||
opendir, open_dataset_dir, saveLabel, resetAll, quit),
|
||||
beginner=(), advanced=(),
|
||||
editMenu=(createpoly, edit, copy, delete,singleRere,None, undo, undoLastPoint,
|
||||
None, color1, self.drawSquaresOption),
|
||||
beginnerContext=(create, edit, copy, delete, singleRere),
|
||||
None, rotateLeft, rotateRight, None, color1, self.drawSquaresOption),
|
||||
beginnerContext=(create, edit, copy, delete, singleRere, rotateLeft, rotateRight,),
|
||||
advancedContext=(createMode, editMode, edit, copy,
|
||||
delete, shapeLineColor, shapeFillColor),
|
||||
onLoadActive=(
|
||||
|
@ -563,9 +581,9 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
self.autoSaveOption.triggered.connect(self.autoSaveFunc)
|
||||
|
||||
addActions(self.menus.file,
|
||||
(opendir, None, saveLabel, saveRec, self.autoSaveOption, None, resetAll, deleteImg, quit))
|
||||
(opendir, open_dataset_dir, None, saveLabel, saveRec, self.autoSaveOption, None, resetAll, deleteImg, quit))
|
||||
|
||||
addActions(self.menus.help, (showSteps, showInfo))
|
||||
addActions(self.menus.help, (showKeys,showSteps, showInfo))
|
||||
addActions(self.menus.view, (
|
||||
self.displayLabelOption, self.labelDialogOption,
|
||||
None,
|
||||
|
@ -760,6 +778,10 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
msg = stepsInfo(self.lang)
|
||||
QMessageBox.information(self, u'Information', msg)
|
||||
|
||||
def showKeysDialog(self):
|
||||
msg = keysInfo(self.lang)
|
||||
QMessageBox.information(self, u'Information', msg)
|
||||
|
||||
def createShape(self):
|
||||
assert self.beginner()
|
||||
self.canvas.setEditing(False)
|
||||
|
@ -773,6 +795,38 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
self.actions.create.setEnabled(False)
|
||||
self.actions.undoLastPoint.setEnabled(True)
|
||||
|
||||
def rotateImg(self, filename, k, _value):
|
||||
|
||||
self.actions.rotateRight.setEnabled(_value)
|
||||
pix = cv2.imread(filename)
|
||||
pix = np.rot90(pix, k)
|
||||
cv2.imwrite(filename, pix)
|
||||
self.canvas.update()
|
||||
self.loadFile(filename)
|
||||
|
||||
def rotateImgWarn(self):
|
||||
if self.lang == 'ch':
|
||||
self.msgBox.warning (self, "提示", "\n 该图片已经有标注框,旋转操作会打乱标注,建议清除标注框后旋转。")
|
||||
else:
|
||||
self.msgBox.warning (self, "Warn", "\n The picture already has a label box, and rotation will disrupt the label.\
|
||||
It is recommended to clear the label box and rotate it.")
|
||||
|
||||
def rotateImgAction(self, k=1, _value=False):
|
||||
|
||||
filename = self.mImgList[self.currIndex]
|
||||
|
||||
if os.path.exists(filename):
|
||||
if self.itemsToShapesbox:
|
||||
self.rotateImgWarn()
|
||||
else:
|
||||
self.saveFile()
|
||||
self.dirty = False
|
||||
self.rotateImg(filename=filename, k=k, _value=True)
|
||||
else:
|
||||
self.rotateImgWarn()
|
||||
self.actions.rotateRight.setEnabled(False)
|
||||
self.actions.rotateLeft.setEnabled(False)
|
||||
|
||||
def toggleDrawingSensitive(self, drawing=True):
|
||||
"""In the middle of drawing, toggling between modes should be disabled."""
|
||||
self.actions.editMode.setEnabled(not drawing)
|
||||
|
@ -880,7 +934,12 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
self.updateComboBox()
|
||||
|
||||
def updateBoxlist(self):
|
||||
for shape in self.canvas.selectedShapes+[self.canvas.hShape]:
|
||||
self.canvas.selectedShapes_hShape = []
|
||||
if self.canvas.hShape != None:
|
||||
self.canvas.selectedShapes_hShape = self.canvas.selectedShapes + [self.canvas.hShape]
|
||||
else:
|
||||
self.canvas.selectedShapes_hShape = self.canvas.selectedShapes
|
||||
for shape in self.canvas.selectedShapes_hShape:
|
||||
item = self.shapesToItemsbox[shape] # listitem
|
||||
text = [(int(p.x()), int(p.y())) for p in shape.points]
|
||||
item.setText(str(text))
|
||||
|
@ -1413,6 +1472,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
|
||||
def loadRecent(self, filename):
|
||||
if self.mayContinue():
|
||||
print(filename,"======")
|
||||
self.loadFile(filename)
|
||||
|
||||
def scanAllImages(self, folderPath):
|
||||
|
@ -1448,6 +1508,23 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
self.lastOpenDir = targetDirPath
|
||||
self.importDirImages(targetDirPath)
|
||||
|
||||
def openDatasetDirDialog(self,):
|
||||
if self.lastOpenDir and os.path.exists(self.lastOpenDir):
|
||||
if platform.system() == 'Windows':
|
||||
os.startfile(self.lastOpenDir)
|
||||
else:
|
||||
os.system('open ' + os.path.normpath(self.lastOpenDir))
|
||||
defaultOpenDirPath = self.lastOpenDir
|
||||
|
||||
else:
|
||||
if self.lang == 'ch':
|
||||
self.msgBox.warning(self, "提示", "\n 原文件夹已不存在,请从新选择数据集路径!")
|
||||
else:
|
||||
self.msgBox.warning(self, "Warn", "\n The original folder no longer exists, please choose the data set path again!")
|
||||
|
||||
self.actions.open_dataset_dir.setEnabled(False)
|
||||
defaultOpenDirPath = os.path.dirname(self.filePath) if self.filePath else '.'
|
||||
|
||||
def importDirImages(self, dirpath, isDelete = False):
|
||||
if not self.mayContinue() or not dirpath:
|
||||
return
|
||||
|
@ -1495,6 +1572,10 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
self.reRecogButton.setEnabled(True)
|
||||
self.actions.AutoRec.setEnabled(True)
|
||||
self.actions.reRec.setEnabled(True)
|
||||
self.actions.open_dataset_dir.setEnabled(True)
|
||||
self.actions.rotateLeft.setEnabled(True)
|
||||
self.actions.rotateRight.setEnabled(True)
|
||||
|
||||
|
||||
|
||||
def openPrevImg(self, _value=False):
|
||||
|
|
|
@ -8,9 +8,12 @@ PPOCRLabel is a semi-automatic graphic annotation tool suitable for OCR field, w
|
|||
|
||||
### Recent Update
|
||||
|
||||
- 2021.8.11:
|
||||
- New functions: Open the dataset folder, image rotation (Note: Please delete the label box before rotating the image) (by [Wei-JL](https://github.com/Wei-JL))
|
||||
- Added shortcut key description (Help-Shortcut Key), repaired the direction shortcut key movement function under batch processing (by [d2623587501](https://github.com/d2623587501))
|
||||
- 2021.2.5: New batch processing and undo functions (by [Evezerest](https://github.com/Evezerest)):
|
||||
- Batch processing function: Press and hold the Ctrl key to select the box, you can move, copy, and delete in batches.
|
||||
- Undo function: In the process of drawing a four-point label box or after editing the box, press Ctrl+Z to undo the previous operation.
|
||||
- **Batch processing function**: Press and hold the Ctrl key to select the box, you can move, copy, and delete in batches.
|
||||
- **Undo function**: In the process of drawing a four-point label box or after editing the box, press Ctrl+Z to undo the previous operation.
|
||||
- Fix image rotation and size problems, optimize the process of editing the mark frame (by [ninetailskim](https://github.com/ninetailskim)、 [edencfc](https://github.com/edencfc)).
|
||||
- 2021.1.11: Optimize the labeling experience (by [edencfc](https://github.com/edencfc)),
|
||||
- Users can choose whether to pop up the label input dialog after drawing the detection box in "View - Pop-up Label Input Dialog".
|
||||
|
@ -23,15 +26,51 @@ PPOCRLabel is a semi-automatic graphic annotation tool suitable for OCR field, w
|
|||
|
||||
## Installation
|
||||
|
||||
### 1. Install PaddleOCR
|
||||
### 1. Environment Preparation
|
||||
|
||||
PaddleOCR models has been built in PPOCRLabel, please refer to [PaddleOCR installation document](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/installation.md) to prepare PaddleOCR and make sure it works.
|
||||
#### **Install PaddlePaddle 2.0**
|
||||
|
||||
```bash
|
||||
pip3 install --upgrade pip
|
||||
|
||||
# If you have cuda9 or cuda10 installed on your machine, please run the following command to install
|
||||
python3 -m pip install paddlepaddle-gpu==2.0.0 -i https://mirror.baidu.com/pypi/simple
|
||||
|
||||
# If you only have cpu on your machine, please run the following command to install
|
||||
python3 -m pip install paddlepaddle==2.0.0 -i https://mirror.baidu.com/pypi/simple
|
||||
```
|
||||
|
||||
For more software version requirements, please refer to the instructions in [Installation Document](https://www.paddlepaddle.org.cn/install/quick) for operation.
|
||||
|
||||
#### **Install PaddleOCR**
|
||||
|
||||
```bash
|
||||
# Recommend
|
||||
git clone https://github.com/PaddlePaddle/PaddleOCR
|
||||
|
||||
# If you cannot pull successfully due to network problems, you can also choose to use the code hosting on the cloud:
|
||||
|
||||
git clone https://gitee.com/paddlepaddle/PaddleOCR
|
||||
|
||||
# Note: The cloud-hosting code may not be able to synchronize the update with this GitHub project in real time. There might be a delay of 3-5 days. Please give priority to the recommended method.
|
||||
```
|
||||
|
||||
#### **Install Third-party Libraries**
|
||||
|
||||
```bash
|
||||
cd PaddleOCR
|
||||
pip3 install -r requirements.txt
|
||||
```
|
||||
|
||||
If you getting this error `OSError: [WinError 126] The specified module could not be found` when you install shapely on windows. Please try to download Shapely whl file using http://www.lfd.uci.edu/~gohlke/pythonlibs/#shapely.
|
||||
|
||||
Reference: [Solve shapely installation on windows](https://stackoverflow.com/questions/44398265/install-shapely-oserror-winerror-126-the-specified-module-could-not-be-found)
|
||||
|
||||
### 2. Install PPOCRLabel
|
||||
|
||||
#### Windows
|
||||
|
||||
```
|
||||
```bash
|
||||
pip install pyqt5
|
||||
cd ./PPOCRLabel # Change the directory to the PPOCRLabel folder
|
||||
python PPOCRLabel.py
|
||||
|
@ -39,15 +78,15 @@ python PPOCRLabel.py
|
|||
|
||||
#### Ubuntu Linux
|
||||
|
||||
```
|
||||
```bash
|
||||
pip3 install pyqt5
|
||||
pip3 install trash-cli
|
||||
cd ./PPOCRLabel # Change the directory to the PPOCRLabel folder
|
||||
python3 PPOCRLabel.py
|
||||
```
|
||||
|
||||
#### macOS
|
||||
```
|
||||
#### MacOS
|
||||
```bash
|
||||
pip3 install pyqt5
|
||||
pip3 uninstall opencv-python # Uninstall opencv manually as it conflicts with pyqt
|
||||
pip3 install opencv-contrib-python-headless==4.2.0.32 # Install the headless version of opencv
|
||||
|
@ -77,11 +116,11 @@ python3 PPOCRLabel.py
|
|||
|
||||
7. Double click the result in 'recognition result' list to manually change inaccurate recognition results.
|
||||
|
||||
8. Click "Check", the image status will switch to "√",then the program automatically jump to the next.
|
||||
8. **Click "Check", the image status will switch to "√",then the program automatically jump to the next.**
|
||||
|
||||
9. Click "Delete Image" and the image will be deleted to the recycle bin.
|
||||
|
||||
10. Labeling result: the user can save manually through the menu "File - Save Label", while the program will also save automatically if "File - Auto Save Label Mode" is selected. The manually checked label will be stored in *Label.txt* under the opened picture folder. Click "PaddleOCR"-"Save Recognition Results" in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*<sup>[4]</sup>.
|
||||
10. Labeling result: the user can export the label result manually through the menu "File - Export Label", while the program will also export automatically if "File - Auto export Label Mode" is selected. The manually checked label will be stored in *Label.txt* under the opened picture folder. Click "File"-"Export Recognition Results" in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*<sup>[4]</sup>.
|
||||
|
||||
### Note
|
||||
|
||||
|
@ -95,10 +134,10 @@ python3 PPOCRLabel.py
|
|||
|
||||
| File name | Description |
|
||||
| :-----------: | :----------------------------------------------------------: |
|
||||
| Label.txt | The detection label file can be directly used for PPOCR detection model training. After the user saves 5 label results, the file will be automatically saved. It will also be written when the user closes the application or changes the file folder. |
|
||||
| Label.txt | The detection label file can be directly used for PPOCR detection model training. After the user saves 5 label results, the file will be automatically exported. It will also be written when the user closes the application or changes the file folder. |
|
||||
| fileState.txt | The picture status file save the image in the current folder that has been manually confirmed by the user. |
|
||||
| Cache.cach | Cache files to save the results of model recognition. |
|
||||
| rec_gt.txt | The recognition label file, which can be directly used for PPOCR identification model training, is generated after the user clicks on the menu bar "File"-"Save recognition result". |
|
||||
| rec_gt.txt | The recognition label file, which can be directly used for PPOCR identification model training, is generated after the user clicks on the menu bar "File"-"Export recognition result". |
|
||||
| crop_img | The recognition data, generated at the same time with *rec_gt.txt* |
|
||||
|
||||
## Explanation
|
||||
|
@ -132,16 +171,16 @@ python3 PPOCRLabel.py
|
|||
|
||||
- Custom model: The model trained by users can be replaced by modifying PPOCRLabel.py in [PaddleOCR class instantiation](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/PPOCRLabel/PPOCRLabel.py#L110) referring [Custom Model Code](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/whl_en.md#use-custom-model)
|
||||
|
||||
### Save
|
||||
### Export Label Result
|
||||
|
||||
PPOCRLabel supports three ways to save Label.txt
|
||||
PPOCRLabel supports three ways to export Label.txt
|
||||
|
||||
- Automatically save: After selecting "File - Auto Save Label Mode", the program will automatically write the annotations into Label.txt every time the user confirms an image. If this option is not turned on, it will be automatically saved after detecting that the user has manually checked 5 images.
|
||||
- Manual save: Click "File-Save Marking Results" to manually save the label.
|
||||
- Close application save
|
||||
- Automatically export: After selecting "File - Auto Export Label Mode", the program will automatically write the annotations into Label.txt every time the user confirms an image. If this option is not turned on, it will be automatically exported after detecting that the user has manually checked 5 images.
|
||||
- Manual export: Click "File-Export Marking Results" to manually export the label.
|
||||
- Close application export
|
||||
|
||||
|
||||
### Export partial recognition results
|
||||
### Export Partial Recognition Results
|
||||
|
||||
For some data that are difficult to recognize, the recognition results will not be exported by **unchecking** the corresponding tags in the recognition results checkbox.
|
||||
|
||||
|
|
|
@ -8,9 +8,12 @@ PPOCRLabel是一款适用于OCR领域的半自动化图形标注工具,内置P
|
|||
|
||||
#### 近期更新
|
||||
|
||||
- 2021.8.11:
|
||||
- 新增功能:打开数据所在文件夹、图像旋转(注意:旋转前的图片上不能存在标记框)(by [Wei-JL](https://github.com/Wei-JL))
|
||||
- 新增快捷键说明(帮助-快捷键)、修复批处理下的方向快捷键移动功能(by [d2623587501](https://github.com/d2623587501))
|
||||
- 2021.2.5:新增批处理与撤销功能(by [Evezerest](https://github.com/Evezerest))
|
||||
- 批处理功能:按住Ctrl键选择标记框后可批量移动、复制、删除。
|
||||
- 撤销功能:在绘制四点标注框过程中或对框进行编辑操作后,按下Ctrl+Z可撤销上一部操作。
|
||||
- **批处理功能**:按住Ctrl键选择标记框后可批量移动、复制、删除、重新识别。
|
||||
- **撤销功能**:在绘制四点标注框过程中或对框进行编辑操作后,按下Ctrl+Z可撤销上一部操作。
|
||||
- 修复图像旋转和尺寸问题、优化编辑标记框过程(by [ninetailskim](https://github.com/ninetailskim)、 [edencfc](https://github.com/edencfc))
|
||||
- 2021.1.11:优化标注体验(by [edencfc](https://github.com/edencfc)):
|
||||
- 用户可在“视图 - 弹出标记输入框”选择在画完检测框后标记输入框是否弹出。
|
||||
|
@ -27,13 +30,48 @@ PPOCRLabel是一款适用于OCR领域的半自动化图形标注工具,内置P
|
|||
|
||||
## 安装
|
||||
|
||||
### 1. 安装PaddleOCR
|
||||
PPOCRLabel内置PaddleOCR模型,故请参考[PaddleOCR安装文档](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/installation.md)准备好PaddleOCR,并确保PaddleOCR安装成功。
|
||||
### 1. 环境搭建
|
||||
#### 安装PaddlePaddle
|
||||
|
||||
```bash
|
||||
pip3 install --upgrade pip
|
||||
|
||||
如果您的机器安装的是CUDA9或CUDA10,请运行以下命令安装
|
||||
python3 -m pip install paddlepaddle-gpu==2.0.0 -i https://mirror.baidu.com/pypi/simple
|
||||
|
||||
如果您的机器是CPU,请运行以下命令安装
|
||||
|
||||
python3 -m pip install paddlepaddle==2.0.0 -i https://mirror.baidu.com/pypi/simple
|
||||
```
|
||||
|
||||
更多的版本需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
|
||||
|
||||
#### **安装PaddleOCR**
|
||||
|
||||
```bash
|
||||
【推荐】git clone https://github.com/PaddlePaddle/PaddleOCR
|
||||
|
||||
如果因为网络问题无法pull成功,也可选择使用码云上的托管:
|
||||
|
||||
git clone https://gitee.com/paddlepaddle/PaddleOCR
|
||||
|
||||
注:码云托管代码可能无法实时同步本github项目更新,存在3~5天延时,请优先使用推荐方式。
|
||||
```
|
||||
|
||||
#### 安装第三方库
|
||||
|
||||
```bash
|
||||
cd PaddleOCR
|
||||
pip3 install -r requirements.txt
|
||||
```
|
||||
|
||||
注意,windows环境下,建议从[这里](https://www.lfd.uci.edu/~gohlke/pythonlibs/#shapely)下载shapely安装包完成安装, 直接通过pip安装的shapely库可能出现`[winRrror 126] 找不到指定模块的问题`。
|
||||
|
||||
### 2. 安装PPOCRLabel
|
||||
|
||||
#### Windows
|
||||
|
||||
```
|
||||
```bash
|
||||
pip install pyqt5
|
||||
cd ./PPOCRLabel # 将目录切换到PPOCRLabel文件夹下
|
||||
python PPOCRLabel.py --lang ch
|
||||
|
@ -41,15 +79,15 @@ python PPOCRLabel.py --lang ch
|
|||
|
||||
#### Ubuntu Linux
|
||||
|
||||
```
|
||||
```bash
|
||||
pip3 install pyqt5
|
||||
pip3 install trash-cli
|
||||
cd ./PPOCRLabel # 将目录切换到PPOCRLabel文件夹下
|
||||
python3 PPOCRLabel.py --lang ch
|
||||
```
|
||||
|
||||
#### macOS
|
||||
```
|
||||
#### MacOS
|
||||
```bash
|
||||
pip3 install pyqt5
|
||||
pip3 uninstall opencv-python # 由于mac版本的opencv与pyqt有冲突,需先手动卸载opencv
|
||||
pip3 install opencv-contrib-python-headless==4.2.0.32 # 安装headless版本的open-cv
|
||||
|
@ -57,6 +95,8 @@ cd ./PPOCRLabel # 将目录切换到PPOCRLabel文件夹下
|
|||
python3 PPOCRLabel.py --lang ch
|
||||
```
|
||||
|
||||
|
||||
|
||||
## 使用
|
||||
|
||||
### 操作步骤
|
||||
|
@ -68,9 +108,9 @@ python3 PPOCRLabel.py --lang ch
|
|||
5. 标记框绘制完成后,用户点击 “确认”,检测框会先被预分配一个 “待识别” 标签。
|
||||
6. 重新识别:将图片中的所有检测画绘制/调整完成后,点击 “重新识别”,PPOCR模型会对当前图片中的**所有检测框**重新识别<sup>[3]</sup>。
|
||||
7. 内容更改:双击识别结果,对不准确的识别结果进行手动更改。
|
||||
8. **确认标记**:点击 “确认”,图片状态切换为 “√”,跳转至下一张。
|
||||
8. **确认标记:点击 “确认”,图片状态切换为 “√”,跳转至下一张。**
|
||||
9. 删除:点击 “删除图像”,图片将会被删除至回收站。
|
||||
10. 保存结果:用户可以通过菜单中“文件-保存标记结果”手动保存,同时也可以点击“文件 - 自动保存标记结果”开启自动保存。手动确认过的标记将会被存放在所打开图片文件夹下的*Label.txt*中。在菜单栏点击 “文件” - "保存识别结果"后,会将此类图片的识别训练数据保存在*crop_img*文件夹下,识别标签保存在*rec_gt.txt*中<sup>[4]</sup>。
|
||||
10. 导出结果:用户可以通过菜单中“文件-导出标记结果”手动导出,同时也可以点击“文件 - 自动导出标记结果”开启自动导出。手动确认过的标记将会被存放在所打开图片文件夹下的*Label.txt*中。在菜单栏点击 “文件” - "导出识别结果"后,会将此类图片的识别训练数据保存在*crop_img*文件夹下,识别标签保存在*rec_gt.txt*中<sup>[4]</sup>。
|
||||
|
||||
### 注意
|
||||
|
||||
|
@ -84,10 +124,10 @@ python3 PPOCRLabel.py --lang ch
|
|||
|
||||
| 文件名 | 说明 |
|
||||
| :-----------: | :----------------------------------------------------------: |
|
||||
| Label.txt | 检测标签,可直接用于PPOCR检测模型训练。用户每保存5张检测结果后,程序会进行自动写入。当用户关闭应用程序或切换文件路径后同样会进行写入。 |
|
||||
| Label.txt | 检测标签,可直接用于PPOCR检测模型训练。用户每确认5张检测结果后,程序会进行自动写入。当用户关闭应用程序或切换文件路径后同样会进行写入。 |
|
||||
| fileState.txt | 图片状态标记文件,保存当前文件夹下已经被用户手动确认过的图片名称。 |
|
||||
| Cache.cach | 缓存文件,保存模型自动识别的结果。 |
|
||||
| rec_gt.txt | 识别标签。可直接用于PPOCR识别模型训练。需用户手动点击菜单栏“文件” - "保存识别结果"后产生。 |
|
||||
| rec_gt.txt | 识别标签。可直接用于PPOCR识别模型训练。需用户手动点击菜单栏“文件” - "导出识别结果"后产生。 |
|
||||
| crop_img | 识别数据。按照检测框切割后的图片。与rec_gt.txt同时产生。 |
|
||||
|
||||
## 说明
|
||||
|
@ -120,19 +160,19 @@ python3 PPOCRLabel.py --lang ch
|
|||
|
||||
- 自定义模型:用户可根据[自定义模型代码使用](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/whl.md#%E8%87%AA%E5%AE%9A%E4%B9%89%E6%A8%A1%E5%9E%8B),通过修改PPOCRLabel.py中针对[PaddleOCR类的实例化](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/PPOCRLabel/PPOCRLabel.py#L110)替换成自己训练的模型。
|
||||
|
||||
### 保存方式
|
||||
### 导出标记结果
|
||||
|
||||
PPOCRLabel支持三种保存方式:
|
||||
PPOCRLabel支持三种导出方式:
|
||||
|
||||
- 自动保存:点击“文件 - 自动保存标记结果”后,用户每确认过一张图片,程序自动将标记结果写入Label.txt中。若未开启此选项,则检测到用户手动确认过5张图片后进行自动保存。
|
||||
- 手动保存:点击“文件 - 保存标记结果”手动保存标记。
|
||||
- 关闭应用程序保存
|
||||
- 自动导出:点击“文件 - 自动导出标记结果”后,用户每确认过一张图片,程序自动将标记结果写入Label.txt中。若未开启此选项,则检测到用户手动确认过5张图片后进行自动导出。
|
||||
- 手动导出:点击“文件 - 导出标记结果”手动导出标记。
|
||||
- 关闭应用程序导出
|
||||
|
||||
### 导出部分识别结果
|
||||
|
||||
针对部分难以识别的数据,通过在识别结果的复选框中**取消勾选**相应的标记,其识别结果不会被导出。
|
||||
|
||||
*注意:识别结果中的复选框状态仍需用户手动点击保存后才能保留*
|
||||
*注意:识别结果中的复选框状态仍需用户手动点击确认后才能保留*
|
||||
|
||||
### 错误提示
|
||||
- 如果同时使用whl包安装了paddleocr,其优先级大于通过paddleocr.py调用PaddleOCR类,whl包未更新时会导致程序异常。
|
||||
|
|
|
@ -23,6 +23,7 @@ except ImportError:
|
|||
|
||||
from libs.shape import Shape
|
||||
from libs.utils import distance
|
||||
import copy
|
||||
|
||||
CURSOR_DEFAULT = Qt.ArrowCursor
|
||||
CURSOR_POINT = Qt.PointingHandCursor
|
||||
|
@ -81,6 +82,7 @@ class Canvas(QWidget):
|
|||
self.fourpoint = True # ADD
|
||||
self.pointnum = 0
|
||||
self.movingShape = False
|
||||
self.selectCountShape = False
|
||||
|
||||
#initialisation for panning
|
||||
self.pan_initial_pos = QPoint()
|
||||
|
@ -702,6 +704,10 @@ class Canvas(QWidget):
|
|||
|
||||
def keyPressEvent(self, ev):
|
||||
key = ev.key()
|
||||
shapesBackup = []
|
||||
shapesBackup = copy.deepcopy(self.shapes)
|
||||
self.shapesBackups.pop()
|
||||
self.shapesBackups.append(shapesBackup)
|
||||
if key == Qt.Key_Escape and self.current:
|
||||
print('ESC press')
|
||||
self.current = None
|
||||
|
@ -709,17 +715,21 @@ class Canvas(QWidget):
|
|||
self.update()
|
||||
elif key == Qt.Key_Return and self.canCloseShape():
|
||||
self.finalise()
|
||||
elif key == Qt.Key_Left and self.selectedShape:
|
||||
elif key == Qt.Key_Left and self.selectedShapes:
|
||||
self.moveOnePixel('Left')
|
||||
elif key == Qt.Key_Right and self.selectedShape:
|
||||
elif key == Qt.Key_Right and self.selectedShapes:
|
||||
self.moveOnePixel('Right')
|
||||
elif key == Qt.Key_Up and self.selectedShape:
|
||||
elif key == Qt.Key_Up and self.selectedShapes:
|
||||
self.moveOnePixel('Up')
|
||||
elif key == Qt.Key_Down and self.selectedShape:
|
||||
elif key == Qt.Key_Down and self.selectedShapes:
|
||||
self.moveOnePixel('Down')
|
||||
|
||||
def moveOnePixel(self, direction):
|
||||
# print(self.selectedShape.points)
|
||||
self.selectCount = len(self.selectedShapes)
|
||||
self.selectCountShape = True
|
||||
for i in range(len(self.selectedShapes)):
|
||||
self.selectedShape = self.selectedShapes[i]
|
||||
if direction == 'Left' and not self.moveOutOfBound(QPointF(-1.0, 0)):
|
||||
# print("move Left one pixel")
|
||||
self.selectedShape.points[0] += QPointF(-1.0, 0)
|
||||
|
@ -744,6 +754,9 @@ class Canvas(QWidget):
|
|||
self.selectedShape.points[1] += QPointF(0, 1.0)
|
||||
self.selectedShape.points[2] += QPointF(0, 1.0)
|
||||
self.selectedShape.points[3] += QPointF(0, 1.0)
|
||||
shapesBackup = []
|
||||
shapesBackup = copy.deepcopy(self.shapes)
|
||||
self.shapesBackups.append(shapesBackup)
|
||||
self.shapeMoved.emit()
|
||||
self.repaint()
|
||||
|
||||
|
@ -840,6 +853,7 @@ class Canvas(QWidget):
|
|||
def restoreShape(self):
|
||||
if not self.isShapeRestorable:
|
||||
return
|
||||
|
||||
self.shapesBackups.pop() # latest
|
||||
shapesBackup = self.shapesBackups.pop()
|
||||
self.shapes = shapesBackup
|
||||
|
|
|
@ -174,6 +174,7 @@ def stepsInfo(lang='en'):
|
|||
"10. 标注结果:关闭应用程序或切换文件路径后,手动保存过的标签将会被存放在所打开图片文件夹下的" \
|
||||
"*Label.txt*中。在菜单栏点击 “PaddleOCR” - 保存识别结果后,会将此类图片的识别训练数据保存在*crop_img*文件夹下," \
|
||||
"识别标签保存在*rec_gt.txt*中。\n"
|
||||
|
||||
else:
|
||||
msg = "1. Build and launch using the instructions above.\n" \
|
||||
"2. Click 'Open Dir' in Menu/File to select the folder of the picture.\n"\
|
||||
|
@ -188,4 +189,56 @@ def stepsInfo(lang='en'):
|
|||
"9. Click 'Delete Image' and the image will be deleted to the recycle bin.\n"\
|
||||
"10. Labeling result: After closing the application or switching the file path, the manually saved label will be stored in *Label.txt* under the opened picture folder.\n"\
|
||||
" Click PaddleOCR-Save Recognition Results in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*.\n"
|
||||
|
||||
return msg
|
||||
|
||||
def keysInfo(lang='en'):
|
||||
if lang == 'ch':
|
||||
msg = "快捷键\t\t\t说明\n" \
|
||||
"———————————————————————\n"\
|
||||
"Ctrl + shift + R\t\t对当前图片的所有标记重新识别\n" \
|
||||
"W\t\t\t新建矩形框\n" \
|
||||
"Q\t\t\t新建四点框\n" \
|
||||
"Ctrl + E\t\t编辑所选框标签\n" \
|
||||
"Ctrl + R\t\t重新识别所选标记\n" \
|
||||
"Ctrl + C\t\t复制并粘贴选中的标记框\n" \
|
||||
"Ctrl + 鼠标左键\t\t多选标记框\n" \
|
||||
"Backspace\t\t删除所选框\n" \
|
||||
"Ctrl + V\t\t确认本张图片标记\n" \
|
||||
"Ctrl + Shift + d\t删除本张图片\n" \
|
||||
"D\t\t\t下一张图片\n" \
|
||||
"A\t\t\t上一张图片\n" \
|
||||
"Ctrl++\t\t\t缩小\n" \
|
||||
"Ctrl--\t\t\t放大\n" \
|
||||
"↑→↓←\t\t\t移动标记框\n" \
|
||||
"———————————————————————\n" \
|
||||
"注:Mac用户Command键替换上述Ctrl键"
|
||||
|
||||
else:
|
||||
msg = "Shortcut Keys\t\tDescription\n" \
|
||||
"———————————————————————\n" \
|
||||
"Ctrl + shift + R\t\tRe-recognize all the labels\n" \
|
||||
"\t\t\tof the current image\n" \
|
||||
"\n"\
|
||||
"W\t\t\tCreate a rect box\n" \
|
||||
"Q\t\t\tCreate a four-points box\n" \
|
||||
"Ctrl + E\t\tEdit label of the selected box\n" \
|
||||
"Ctrl + R\t\tRe-recognize the selected box\n" \
|
||||
"Ctrl + C\t\tCopy and paste the selected\n" \
|
||||
"\t\t\tbox\n" \
|
||||
"\n"\
|
||||
"Ctrl + Left Mouse\tMulti select the label\n" \
|
||||
"Button\t\t\tbox\n" \
|
||||
"\n"\
|
||||
"Backspace\t\tDelete the selected box\n" \
|
||||
"Ctrl + V\t\tCheck image\n" \
|
||||
"Ctrl + Shift + d\tDelete image\n" \
|
||||
"D\t\t\tNext image\n" \
|
||||
"A\t\t\tPrevious image\n" \
|
||||
"Ctrl++\t\t\tZoom in\n" \
|
||||
"Ctrl--\t\t\tZoom out\n" \
|
||||
"↑→↓←\t\t\tMove selected box" \
|
||||
"———————————————————————\n" \
|
||||
"Notice:For Mac users, use the 'Command' key instead of the 'Ctrl' key"
|
||||
|
||||
return msg
|
|
@ -18,6 +18,8 @@
|
|||
<file alias="quit">resources/icons/quit.png</file>
|
||||
<file alias="copy">resources/icons/copy.png</file>
|
||||
<file alias="edit">resources/icons/edit.png</file>
|
||||
<file alias="rotateLeft">resources/icons/rotateLeft.png</file>
|
||||
<file alias="rotateRight">resources/icons/rotateRight.png</file>
|
||||
<file alias="open">resources/icons/open.png</file>
|
||||
<file alias="save">resources/icons/save.png</file>
|
||||
<file alias="format_voc">resources/icons/format_voc.png</file>
|
||||
|
|
After Width: | Height: | Size: 4.1 KiB |
After Width: | Height: | Size: 4.1 KiB |
|
@ -31,6 +31,7 @@ save=确认
|
|||
saveAs=另存为
|
||||
fitWinDetail=缩放到当前窗口大小
|
||||
openDir=打开目录
|
||||
openDatasetDir=打开数据集路径
|
||||
copyPrevBounding=复制当前图像中的上一个边界框
|
||||
showHide=显示/隐藏标签
|
||||
changeSaveFormat=更改存储格式
|
||||
|
@ -85,19 +86,22 @@ detectionBoxposition=检测框位置
|
|||
recognitionResult=识别结果
|
||||
creatPolygon=四点标注
|
||||
drawSquares=正方形标注
|
||||
saveRec=保存识别结果
|
||||
rotateLeft=图片左旋转90度
|
||||
rotateRight=图片右旋转90度
|
||||
saveRec=导出识别结果
|
||||
tempLabel=待识别
|
||||
nullLabel=无法识别
|
||||
steps=操作步骤
|
||||
keys=快捷键
|
||||
choseModelLg=选择模型语言
|
||||
cancel=取消
|
||||
ok=确认
|
||||
autolabeling=自动标注中
|
||||
hideBox=隐藏所有标注
|
||||
showBox=显示所有标注
|
||||
saveLabel=保存标记结果
|
||||
saveLabel=导出标记结果
|
||||
singleRe=重识别此区块
|
||||
labelDialogOption=弹出标记输入框
|
||||
undo=撤销
|
||||
undoLastPoint=撤销上个点
|
||||
autoSaveMode=自动保存标记结果
|
||||
autoSaveMode=自动导出标记结果
|
|
@ -3,6 +3,7 @@ openFileDetail=Open image or label file
|
|||
quit=Quit
|
||||
quitApp=Quit application
|
||||
openDir=Open Dir
|
||||
openDatasetDir=Open DatasetDir
|
||||
copyPrevBounding=Copy previous Bounding Boxes in the current image
|
||||
changeSavedAnnotationDir=Change default saved Annotation dir
|
||||
openAnnotation=Open Annotation
|
||||
|
@ -84,20 +85,23 @@ iconList=Icon List
|
|||
detectionBoxposition=Detection box position
|
||||
recognitionResult=Recognition result
|
||||
creatPolygon=Create Quadrilateral
|
||||
rotateLeft=Left turn 90 degrees
|
||||
rotateRight=Right turn 90 degrees
|
||||
drawSquares=Draw Squares
|
||||
saveRec=Save Recognition Result
|
||||
saveRec=Export Recognition Result
|
||||
tempLabel=TEMPORARY
|
||||
nullLabel=NULL
|
||||
steps=Steps
|
||||
keys=Shortcut Keys
|
||||
choseModelLg=Choose Model Language
|
||||
cancel=Cancel
|
||||
ok=OK
|
||||
autolabeling=Automatic Labeling
|
||||
hideBox=Hide All Box
|
||||
showBox=Show All Box
|
||||
saveLabel=Save Label
|
||||
saveLabel=Export Label
|
||||
singleRe=Re-recognition RectBox
|
||||
labelDialogOption=Pop-up Label Input Dialog
|
||||
undo=Undo
|
||||
undoLastPoint=Undo Last Point
|
||||
autoSaveMode=Auto Save Label Mode
|
||||
autoSaveMode=Auto Export Label Mode
|
|
@ -66,6 +66,7 @@ class StdTextDrawer(object):
|
|||
corpus_list.append(corpus[0:i])
|
||||
text_input_list.append(text_input)
|
||||
corpus = corpus[i:]
|
||||
i = 0
|
||||
break
|
||||
draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
|
||||
char_x += char_size
|
||||
|
@ -78,7 +79,6 @@ class StdTextDrawer(object):
|
|||
|
||||
corpus_list.append(corpus[0:i])
|
||||
text_input_list.append(text_input)
|
||||
corpus = corpus[i:]
|
||||
break
|
||||
|
||||
return corpus_list, text_input_list
|
||||
|
|
|
@ -11,7 +11,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import paddleocr
|
||||
from .paddleocr import *
|
||||
|
||||
__all__ = ['PaddleOCR', 'draw_ocr']
|
||||
from .paddleocr import PaddleOCR
|
||||
from .tools.infer.utility import draw_ocr
|
||||
__version__ = paddleocr.VERSION
|
||||
__all__ = ['PaddleOCR', 'PPStructure', 'draw_ocr', 'draw_structure_result', 'save_structure_res','download_with_progressbar']
|
||||
|
|
|
@ -0,0 +1,202 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/ch_db_mv3/
|
||||
save_epoch_step: 1200
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [3000, 2000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
|
||||
Architecture:
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Student:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Student2:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Teacher:
|
||||
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
|
||||
freeze_params: true
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 18
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 256
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- DistillationDilaDBLoss:
|
||||
weight: 1.0
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
- ["Student2", "Teacher"]
|
||||
key: maps
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
- DistillationDMLLoss:
|
||||
model_name_pairs:
|
||||
- ["Student", "Student2"]
|
||||
maps_name: "thrink_maps"
|
||||
weight: 1.0
|
||||
# act: None
|
||||
model_name_pairs: ["Student", "Student2"]
|
||||
key: maps
|
||||
- DistillationDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Student2"]
|
||||
# key: maps
|
||||
# name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DistillationDBPostProcess
|
||||
model_name: ["Student", "Student2", "Teacher"]
|
||||
# key: maps
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DistillationMetric
|
||||
base_metric_name: DetMetric
|
||||
main_indicator: hmean
|
||||
key: "Student"
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||
- EastRandomCropData:
|
||||
size: [960, 960]
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
# image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -0,0 +1,174 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/ch_db_mv3/
|
||||
save_epoch_step: 1200
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [3000, 2000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
|
||||
Architecture:
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Student:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Teacher:
|
||||
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
|
||||
freeze_params: true
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 18
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 256
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- DistillationDilaDBLoss:
|
||||
weight: 1.0
|
||||
model_name_pairs:
|
||||
- ["Student", "Teacher"]
|
||||
key: maps
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
- DistillationDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Teacher"]
|
||||
# key: maps
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DistillationDBPostProcess
|
||||
model_name: ["Student", "Student2"]
|
||||
key: head_out
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DistillationMetric
|
||||
base_metric_name: DetMetric
|
||||
main_indicator: hmean
|
||||
key: "Student"
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||
- EastRandomCropData:
|
||||
size: [960, 960]
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
# image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -0,0 +1,176 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 1200
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 2
|
||||
save_model_dir: ./output/ch_db_mv3/
|
||||
save_epoch_step: 1200
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [3000, 2000]
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_en/img_10.jpg
|
||||
save_res_path: ./output/det_db/predicts_db.txt
|
||||
|
||||
Architecture:
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Student:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
Student2:
|
||||
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
|
||||
freeze_params: false
|
||||
return_all_feats: false
|
||||
model_type: det
|
||||
algorithm: DB
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: large
|
||||
disable_se: True
|
||||
Neck:
|
||||
name: DBFPN
|
||||
out_channels: 96
|
||||
Head:
|
||||
name: DBHead
|
||||
k: 50
|
||||
|
||||
|
||||
Loss:
|
||||
name: CombinedLoss
|
||||
loss_config_list:
|
||||
- DistillationDMLLoss:
|
||||
model_name_pairs:
|
||||
- ["Student", "Student2"]
|
||||
maps_name: "thrink_maps"
|
||||
weight: 1.0
|
||||
act: "softmax"
|
||||
model_name_pairs: ["Student", "Student2"]
|
||||
key: maps
|
||||
- DistillationDBLoss:
|
||||
weight: 1.0
|
||||
model_name_list: ["Student", "Student2"]
|
||||
# key: maps
|
||||
name: DBLoss
|
||||
balance_loss: true
|
||||
main_loss_type: DiceLoss
|
||||
alpha: 5
|
||||
beta: 10
|
||||
ohem_ratio: 3
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
warmup_epoch: 2
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
PostProcess:
|
||||
name: DistillationDBPostProcess
|
||||
model_name: ["Student", "Student2"]
|
||||
key: head_out
|
||||
thresh: 0.3
|
||||
box_thresh: 0.6
|
||||
max_candidates: 1000
|
||||
unclip_ratio: 1.5
|
||||
|
||||
Metric:
|
||||
name: DistillationMetric
|
||||
base_metric_name: DetMetric
|
||||
main_indicator: hmean
|
||||
key: "Student"
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
|
||||
ratio_list: [1.0]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- IaaAugment:
|
||||
augmenter_args:
|
||||
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
|
||||
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
|
||||
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
|
||||
- EastRandomCropData:
|
||||
size: [960, 960]
|
||||
max_tries: 50
|
||||
keep_ratio: true
|
||||
- MakeBorderMap:
|
||||
shrink_ratio: 0.4
|
||||
thresh_min: 0.3
|
||||
thresh_max: 0.7
|
||||
- MakeShrinkMap:
|
||||
shrink_ratio: 0.4
|
||||
min_text_size: 8
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: False
|
||||
batch_size_per_card: 8
|
||||
num_workers: 4
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/icdar2015/text_localization/
|
||||
label_file_list:
|
||||
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- DetLabelEncode: # Class handling label
|
||||
- DetResizeForTest:
|
||||
# image_shape: [736, 1280]
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -17,7 +17,7 @@ Global:
|
|||
character_type: ch
|
||||
max_text_length: 25
|
||||
infer_mode: false
|
||||
use_space_char: false
|
||||
use_space_char: true
|
||||
distributed: true
|
||||
save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt
|
||||
|
||||
|
@ -27,54 +27,55 @@ Optimizer:
|
|||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.0005
|
||||
name: Piecewise
|
||||
decay_epochs : [700, 800]
|
||||
values : [0.001, 0.0001]
|
||||
warmup_epoch: 5
|
||||
regularizer:
|
||||
name: L2
|
||||
factor: 1.0e-05
|
||||
factor: 2.0e-05
|
||||
|
||||
Architecture:
|
||||
model_type: &model_type "rec"
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Student:
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: rec
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: small
|
||||
small_stride: [1, 2, 2, 2]
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 48
|
||||
Head:
|
||||
name: CTCHead
|
||||
fc_decay: 0.00001
|
||||
Teacher:
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: rec
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
model_name: small
|
||||
small_stride: [1, 2, 2, 2]
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 48
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
fc_decay: 0.00001
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
Student:
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
|
||||
|
||||
Loss:
|
||||
|
|
|
@ -10,7 +10,7 @@ Global:
|
|||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
save_inference_dir: ./
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words_en/word_10.png
|
||||
# for data or label process
|
||||
|
@ -60,8 +60,8 @@ Metric:
|
|||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/
|
||||
label_file_list: ["./train_data/train_list.txt"]
|
||||
data_dir: ./train_data/ic15_data/
|
||||
label_file_list: ["./train_data/ic15_data/rec_gt_train.txt"]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
|
@ -81,8 +81,8 @@ Train:
|
|||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: ./train_data/
|
||||
label_file_list: ["./train_data/val_list.txt"]
|
||||
data_dir: ./train_data/ic15_data
|
||||
label_file_list: ["./train_data/ic15_data/rec_gt_test.txt"]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
|
|
|
@ -0,0 +1,116 @@
|
|||
Global:
|
||||
use_gpu: true
|
||||
epoch_num: 50
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 5
|
||||
save_model_dir: ./output/table_mv3/
|
||||
save_epoch_step: 5
|
||||
# evaluation is run every 400 iterations after the 0th iteration
|
||||
eval_batch_step: [0, 400]
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img: doc/imgs_words/ch/word_1.jpg
|
||||
# for data or label process
|
||||
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
|
||||
character_type: en
|
||||
max_text_length: 100
|
||||
max_elem_length: 500
|
||||
max_cell_num: 500
|
||||
infer_mode: False
|
||||
process_total_num: 0
|
||||
process_cut_num: 0
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
clip_norm: 5.0
|
||||
lr:
|
||||
learning_rate: 0.001
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0.00000
|
||||
|
||||
Architecture:
|
||||
model_type: table
|
||||
algorithm: TableAttn
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 1.0
|
||||
model_name: small
|
||||
disable_se: True
|
||||
Head:
|
||||
name: TableAttentionHead
|
||||
hidden_size: 256
|
||||
l2_decay: 0.00001
|
||||
loc_type: 2
|
||||
|
||||
Loss:
|
||||
name: TableAttentionLoss
|
||||
structure_weight: 100.0
|
||||
loc_weight: 10000.0
|
||||
|
||||
PostProcess:
|
||||
name: TableLabelDecode
|
||||
|
||||
Metric:
|
||||
name: TableMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: PubTabDataSet
|
||||
data_dir: train_data/table/pubtabnet/train/
|
||||
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- ResizeTableImage:
|
||||
max_len: 488
|
||||
- TableLabelEncode:
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- PaddingTableImage:
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 32
|
||||
drop_last: True
|
||||
num_workers: 1
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: PubTabDataSet
|
||||
data_dir: train_data/table/pubtabnet/val/
|
||||
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- ResizeTableImage:
|
||||
max_len: 488
|
||||
- TableLabelEncode:
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: 'hwc'
|
||||
- PaddingTableImage:
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 16
|
||||
num_workers: 1
|
|
@ -1,4 +1,4 @@
|
|||
project(ocr_system CXX C)
|
||||
project(ppocr CXX C)
|
||||
|
||||
option(WITH_MKL "Compile demo with MKL/OpenBlas support, default use MKL." ON)
|
||||
option(WITH_GPU "Compile demo with GPU/CPU, default use CPU." OFF)
|
||||
|
@ -11,7 +11,7 @@ SET(CUDA_LIB "" CACHE PATH "Location of libraries")
|
|||
SET(CUDNN_LIB "" CACHE PATH "Location of libraries")
|
||||
SET(TENSORRT_DIR "" CACHE PATH "Compile demo with TensorRT")
|
||||
|
||||
set(DEMO_NAME "ocr_system")
|
||||
set(DEMO_NAME "ppocr")
|
||||
|
||||
|
||||
macro(safe_set_static_flag)
|
||||
|
@ -38,10 +38,8 @@ endif()
|
|||
|
||||
|
||||
if (WIN32)
|
||||
include_directories("${PADDLE_LIB}/paddle/fluid/inference")
|
||||
include_directories("${PADDLE_LIB}/paddle/include")
|
||||
link_directories("${PADDLE_LIB}/paddle/lib")
|
||||
link_directories("${PADDLE_LIB}/paddle/fluid/inference")
|
||||
find_package(OpenCV REQUIRED PATHS ${OPENCV_DIR}/build/ NO_DEFAULT_PATH)
|
||||
|
||||
else ()
|
||||
|
|
|
@ -14,7 +14,7 @@ PaddleOCR在Windows 平台下基于`Visual Studio 2019 Community` 进行了测
|
|||
|
||||
### Step1: 下载PaddlePaddle C++ 预测库 fluid_inference
|
||||
|
||||
PaddlePaddle C++ 预测库针对不同的`CPU`和`CUDA`版本提供了不同的预编译版本,请根据实际情况下载: [C++预测库下载列表](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/05_inference_deployment/inference/windows_cpp_inference.html)
|
||||
PaddlePaddle C++ 预测库针对不同的`CPU`和`CUDA`版本提供了不同的预编译版本,请根据实际情况下载: [C++预测库下载列表](https://paddleinference.paddlepaddle.org.cn/user_guides/download_lib.html#windows)
|
||||
|
||||
解压后`D:\projects\fluid_inference`目录包含内容为:
|
||||
```
|
||||
|
@ -93,3 +93,5 @@ cd D:\projects\PaddleOCR\deploy\cpp_infer\out\build\x64-Release
|
|||
|
||||
### 注意
|
||||
* 在Windows下的终端中执行文件exe时,可能会发生乱码的现象,此时需要在终端中输入`CHCP 65001`,将终端的编码方式由GBK编码(默认)改为UTF-8编码,更加具体的解释可以参考这篇博客:[https://blog.csdn.net/qq_35038153/article/details/78430359](https://blog.csdn.net/qq_35038153/article/details/78430359)。
|
||||
|
||||
* 编译时,如果报错`错误:C1083 无法打开包括文件:"dirent.h":No such file or directory`,可以参考该[文档](https://blog.csdn.net/Dora_blank/article/details/117740837#41_C1083_direnthNo_such_file_or_directory_54),新建`dirent.h`文件,并添加到`VC++`的包含目录中。
|
||||
|
|
|
@ -668,7 +668,7 @@ void DisposeOutPts(OutPt *&pp) {
|
|||
//------------------------------------------------------------------------------
|
||||
|
||||
inline void InitEdge(TEdge *e, TEdge *eNext, TEdge *ePrev, const IntPoint &Pt) {
|
||||
std::memset(e, 0, sizeof(TEdge));
|
||||
std::memset(e, int(0), sizeof(TEdge));
|
||||
e->Next = eNext;
|
||||
e->Prev = ePrev;
|
||||
e->Curr = Pt;
|
||||
|
@ -1895,17 +1895,17 @@ void Clipper::InsertLocalMinimaIntoAEL(const cInt botY) {
|
|||
TEdge *rb = lm->RightBound;
|
||||
|
||||
OutPt *Op1 = 0;
|
||||
if (!lb) {
|
||||
if (!lb || !rb) {
|
||||
// nb: don't insert LB into either AEL or SEL
|
||||
InsertEdgeIntoAEL(rb, 0);
|
||||
SetWindingCount(*rb);
|
||||
if (IsContributing(*rb))
|
||||
Op1 = AddOutPt(rb, rb->Bot);
|
||||
} else if (!rb) {
|
||||
InsertEdgeIntoAEL(lb, 0);
|
||||
SetWindingCount(*lb);
|
||||
if (IsContributing(*lb))
|
||||
Op1 = AddOutPt(lb, lb->Bot);
|
||||
//} else if (!rb) {
|
||||
// InsertEdgeIntoAEL(lb, 0);
|
||||
// SetWindingCount(*lb);
|
||||
// if (IsContributing(*lb))
|
||||
// Op1 = AddOutPt(lb, lb->Bot);
|
||||
InsertScanbeam(lb->Top.Y);
|
||||
} else {
|
||||
InsertEdgeIntoAEL(lb, 0);
|
||||
|
@ -2547,13 +2547,13 @@ void Clipper::ProcessHorizontal(TEdge *horzEdge) {
|
|||
if (dir == dLeftToRight) {
|
||||
maxIt = m_Maxima.begin();
|
||||
while (maxIt != m_Maxima.end() && *maxIt <= horzEdge->Bot.X)
|
||||
maxIt++;
|
||||
++maxIt;
|
||||
if (maxIt != m_Maxima.end() && *maxIt >= eLastHorz->Top.X)
|
||||
maxIt = m_Maxima.end();
|
||||
} else {
|
||||
maxRit = m_Maxima.rbegin();
|
||||
while (maxRit != m_Maxima.rend() && *maxRit > horzEdge->Bot.X)
|
||||
maxRit++;
|
||||
++maxRit;
|
||||
if (maxRit != m_Maxima.rend() && *maxRit <= eLastHorz->Top.X)
|
||||
maxRit = m_Maxima.rend();
|
||||
}
|
||||
|
@ -2576,13 +2576,13 @@ void Clipper::ProcessHorizontal(TEdge *horzEdge) {
|
|||
while (maxIt != m_Maxima.end() && *maxIt < e->Curr.X) {
|
||||
if (horzEdge->OutIdx >= 0 && !IsOpen)
|
||||
AddOutPt(horzEdge, IntPoint(*maxIt, horzEdge->Bot.Y));
|
||||
maxIt++;
|
||||
++maxIt;
|
||||
}
|
||||
} else {
|
||||
while (maxRit != m_Maxima.rend() && *maxRit > e->Curr.X) {
|
||||
if (horzEdge->OutIdx >= 0 && !IsOpen)
|
||||
AddOutPt(horzEdge, IntPoint(*maxRit, horzEdge->Bot.Y));
|
||||
maxRit++;
|
||||
++maxRit;
|
||||
}
|
||||
}
|
||||
};
|
|
@ -31,6 +31,8 @@
|
|||
* *
|
||||
*******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef clipper_hpp
|
||||
#define clipper_hpp
|
||||
|
||||
|
|
|
@ -1,123 +0,0 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <map>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "include/utility.h"
|
||||
|
||||
namespace PaddleOCR {
|
||||
|
||||
class OCRConfig {
|
||||
public:
|
||||
explicit OCRConfig(const std::string &config_file) {
|
||||
config_map_ = LoadConfig(config_file);
|
||||
|
||||
this->use_gpu = bool(stoi(config_map_["use_gpu"]));
|
||||
|
||||
this->gpu_id = stoi(config_map_["gpu_id"]);
|
||||
|
||||
this->gpu_mem = stoi(config_map_["gpu_mem"]);
|
||||
|
||||
this->cpu_math_library_num_threads =
|
||||
stoi(config_map_["cpu_math_library_num_threads"]);
|
||||
|
||||
this->use_mkldnn = bool(stoi(config_map_["use_mkldnn"]));
|
||||
|
||||
this->max_side_len = stoi(config_map_["max_side_len"]);
|
||||
|
||||
this->det_db_thresh = stod(config_map_["det_db_thresh"]);
|
||||
|
||||
this->det_db_box_thresh = stod(config_map_["det_db_box_thresh"]);
|
||||
|
||||
this->det_db_unclip_ratio = stod(config_map_["det_db_unclip_ratio"]);
|
||||
|
||||
this->use_polygon_score = bool(stoi(config_map_["use_polygon_score"]));
|
||||
|
||||
this->det_model_dir.assign(config_map_["det_model_dir"]);
|
||||
|
||||
this->rec_model_dir.assign(config_map_["rec_model_dir"]);
|
||||
|
||||
this->char_list_file.assign(config_map_["char_list_file"]);
|
||||
|
||||
this->use_angle_cls = bool(stoi(config_map_["use_angle_cls"]));
|
||||
|
||||
this->cls_model_dir.assign(config_map_["cls_model_dir"]);
|
||||
|
||||
this->cls_thresh = stod(config_map_["cls_thresh"]);
|
||||
|
||||
this->visualize = bool(stoi(config_map_["visualize"]));
|
||||
|
||||
this->use_tensorrt = bool(stoi(config_map_["use_tensorrt"]));
|
||||
|
||||
this->use_fp16 = bool(stod(config_map_["use_fp16"]));
|
||||
}
|
||||
|
||||
bool use_gpu = false;
|
||||
|
||||
int gpu_id = 0;
|
||||
|
||||
int gpu_mem = 4000;
|
||||
|
||||
int cpu_math_library_num_threads = 1;
|
||||
|
||||
bool use_mkldnn = false;
|
||||
|
||||
int max_side_len = 960;
|
||||
|
||||
double det_db_thresh = 0.3;
|
||||
|
||||
double det_db_box_thresh = 0.5;
|
||||
|
||||
double det_db_unclip_ratio = 2.0;
|
||||
|
||||
bool use_polygon_score = false;
|
||||
|
||||
std::string det_model_dir;
|
||||
|
||||
std::string rec_model_dir;
|
||||
|
||||
bool use_angle_cls;
|
||||
|
||||
std::string char_list_file;
|
||||
|
||||
std::string cls_model_dir;
|
||||
|
||||
double cls_thresh;
|
||||
|
||||
bool visualize = true;
|
||||
|
||||
bool use_tensorrt = false;
|
||||
|
||||
bool use_fp16 = false;
|
||||
|
||||
void PrintConfigInfo();
|
||||
|
||||
private:
|
||||
// Load configuration
|
||||
std::map<std::string, std::string> LoadConfig(const std::string &config_file);
|
||||
|
||||
std::vector<std::string> split(const std::string &str,
|
||||
const std::string &delim);
|
||||
|
||||
std::map<std::string, std::string> config_map_;
|
||||
};
|
||||
|
||||
} // namespace PaddleOCR
|
|
@ -12,6 +12,8 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "opencv2/core.hpp"
|
||||
#include "opencv2/imgcodecs.hpp"
|
||||
#include "opencv2/imgproc.hpp"
|
||||
|
@ -40,7 +42,7 @@ public:
|
|||
const int &gpu_id, const int &gpu_mem,
|
||||
const int &cpu_math_library_num_threads,
|
||||
const bool &use_mkldnn, const double &cls_thresh,
|
||||
const bool &use_tensorrt, const bool &use_fp16) {
|
||||
const bool &use_tensorrt, const std::string &precision) {
|
||||
this->use_gpu_ = use_gpu;
|
||||
this->gpu_id_ = gpu_id;
|
||||
this->gpu_mem_ = gpu_mem;
|
||||
|
@ -49,7 +51,7 @@ public:
|
|||
|
||||
this->cls_thresh = cls_thresh;
|
||||
this->use_tensorrt_ = use_tensorrt;
|
||||
this->use_fp16_ = use_fp16;
|
||||
this->precision_ = precision;
|
||||
|
||||
LoadModel(model_dir);
|
||||
}
|
||||
|
@ -73,7 +75,7 @@ private:
|
|||
std::vector<float> scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
|
||||
bool is_scale_ = true;
|
||||
bool use_tensorrt_ = false;
|
||||
bool use_fp16_ = false;
|
||||
std::string precision_ = "fp32";
|
||||
// pre-process
|
||||
ClsResizeImg resize_op_;
|
||||
Normalize normalize_op_;
|
||||
|
|
|
@ -46,7 +46,7 @@ public:
|
|||
const double &det_db_box_thresh,
|
||||
const double &det_db_unclip_ratio,
|
||||
const bool &use_polygon_score, const bool &visualize,
|
||||
const bool &use_tensorrt, const bool &use_fp16) {
|
||||
const bool &use_tensorrt, const std::string &precision) {
|
||||
this->use_gpu_ = use_gpu;
|
||||
this->gpu_id_ = gpu_id;
|
||||
this->gpu_mem_ = gpu_mem;
|
||||
|
@ -62,7 +62,7 @@ public:
|
|||
|
||||
this->visualize_ = visualize;
|
||||
this->use_tensorrt_ = use_tensorrt;
|
||||
this->use_fp16_ = use_fp16;
|
||||
this->precision_ = precision;
|
||||
|
||||
LoadModel(model_dir);
|
||||
}
|
||||
|
@ -71,7 +71,7 @@ public:
|
|||
void LoadModel(const std::string &model_dir);
|
||||
|
||||
// Run predictor
|
||||
void Run(cv::Mat &img, std::vector<std::vector<std::vector<int>>> &boxes);
|
||||
void Run(cv::Mat &img, std::vector<std::vector<std::vector<int>>> &boxes, std::vector<double> *times);
|
||||
|
||||
private:
|
||||
std::shared_ptr<Predictor> predictor_;
|
||||
|
@ -91,7 +91,7 @@ private:
|
|||
|
||||
bool visualize_ = true;
|
||||
bool use_tensorrt_ = false;
|
||||
bool use_fp16_ = false;
|
||||
std::string precision_ = "fp32";
|
||||
|
||||
std::vector<float> mean_ = {0.485f, 0.456f, 0.406f};
|
||||
std::vector<float> scale_ = {1 / 0.229f, 1 / 0.224f, 1 / 0.225f};
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "opencv2/core.hpp"
|
||||
#include "opencv2/imgcodecs.hpp"
|
||||
#include "opencv2/imgproc.hpp"
|
||||
|
@ -42,14 +44,14 @@ public:
|
|||
const int &gpu_id, const int &gpu_mem,
|
||||
const int &cpu_math_library_num_threads,
|
||||
const bool &use_mkldnn, const string &label_path,
|
||||
const bool &use_tensorrt, const bool &use_fp16) {
|
||||
const bool &use_tensorrt, const std::string &precision) {
|
||||
this->use_gpu_ = use_gpu;
|
||||
this->gpu_id_ = gpu_id;
|
||||
this->gpu_mem_ = gpu_mem;
|
||||
this->cpu_math_library_num_threads_ = cpu_math_library_num_threads;
|
||||
this->use_mkldnn_ = use_mkldnn;
|
||||
this->use_tensorrt_ = use_tensorrt;
|
||||
this->use_fp16_ = use_fp16;
|
||||
this->precision_ = precision;
|
||||
|
||||
this->label_list_ = Utility::ReadDict(label_path);
|
||||
this->label_list_.insert(this->label_list_.begin(),
|
||||
|
@ -62,8 +64,7 @@ public:
|
|||
// Load Paddle inference model
|
||||
void LoadModel(const std::string &model_dir);
|
||||
|
||||
void Run(std::vector<std::vector<std::vector<int>>> boxes, cv::Mat &img,
|
||||
Classifier *cls);
|
||||
void Run(cv::Mat &img, std::vector<double> *times);
|
||||
|
||||
private:
|
||||
std::shared_ptr<Predictor> predictor_;
|
||||
|
@ -80,7 +81,7 @@ private:
|
|||
std::vector<float> scale_ = {1 / 0.5f, 1 / 0.5f, 1 / 0.5f};
|
||||
bool is_scale_ = true;
|
||||
bool use_tensorrt_ = false;
|
||||
bool use_fp16_ = false;
|
||||
std::string precision_ = "fp32";
|
||||
// pre-process
|
||||
CrnnResizeImg resize_op_;
|
||||
Normalize normalize_op_;
|
||||
|
@ -89,9 +90,6 @@ private:
|
|||
// post-process
|
||||
PostProcessor post_processor_;
|
||||
|
||||
cv::Mat GetRotateCropImage(const cv::Mat &srcimage,
|
||||
std::vector<std::vector<int>> box);
|
||||
|
||||
}; // class CrnnRecognizer
|
||||
|
||||
} // namespace PaddleOCR
|
||||
|
|
|
@ -47,6 +47,9 @@ public:
|
|||
|
||||
static void GetAllFiles(const char *dir_name,
|
||||
std::vector<std::string> &all_inputs);
|
||||
|
||||
static cv::Mat GetRotateCropImage(const cv::Mat &srcimage,
|
||||
std::vector<std::vector<int>> box);
|
||||
};
|
||||
|
||||
} // namespace PaddleOCR
|
|
@ -18,6 +18,7 @@ PaddleOCR模型部署。
|
|||
* 首先需要从opencv官网上下载在Linux环境下源码编译的包,以opencv3.4.7为例,下载命令如下。
|
||||
|
||||
```
|
||||
cd deploy/cpp_infer
|
||||
wget https://github.com/opencv/opencv/archive/3.4.7.tar.gz
|
||||
tar -xf 3.4.7.tar.gz
|
||||
```
|
||||
|
@ -153,82 +154,102 @@ inference/
|
|||
|
||||
* 编译命令如下,其中Paddle C++预测库、opencv等其他依赖库的地址需要换成自己机器上的实际地址。
|
||||
|
||||
|
||||
```shell
|
||||
sh tools/build.sh
|
||||
```
|
||||
|
||||
具体地,`tools/build.sh`中内容如下。
|
||||
* 具体的,需要修改`tools/build.sh`中环境路径,相关内容如下:
|
||||
|
||||
```shell
|
||||
OPENCV_DIR=your_opencv_dir
|
||||
LIB_DIR=your_paddle_inference_dir
|
||||
CUDA_LIB_DIR=your_cuda_lib_dir
|
||||
CUDNN_LIB_DIR=/your_cudnn_lib_dir
|
||||
|
||||
BUILD_DIR=build
|
||||
rm -rf ${BUILD_DIR}
|
||||
mkdir ${BUILD_DIR}
|
||||
cd ${BUILD_DIR}
|
||||
cmake .. \
|
||||
-DPADDLE_LIB=${LIB_DIR} \
|
||||
-DWITH_MKL=ON \
|
||||
-DDEMO_NAME=ocr_system \
|
||||
-DWITH_GPU=OFF \
|
||||
-DWITH_STATIC_LIB=OFF \
|
||||
-DUSE_TENSORRT=OFF \
|
||||
-DOPENCV_DIR=${OPENCV_DIR} \
|
||||
-DCUDNN_LIB=${CUDNN_LIB_DIR} \
|
||||
-DCUDA_LIB=${CUDA_LIB_DIR} \
|
||||
|
||||
make -j
|
||||
```
|
||||
|
||||
`OPENCV_DIR`为opencv编译安装的地址;`LIB_DIR`为下载(`paddle_inference`文件夹)或者编译生成的Paddle预测库地址(`build/paddle_inference_install_dir`文件夹);`CUDA_LIB_DIR`为cuda库文件地址,在docker中为`/usr/local/cuda/lib64`;`CUDNN_LIB_DIR`为cudnn库文件地址,在docker中为`/usr/lib/x86_64-linux-gnu/`。
|
||||
其中,`OPENCV_DIR`为opencv编译安装的地址;`LIB_DIR`为下载(`paddle_inference`文件夹)或者编译生成的Paddle预测库地址(`build/paddle_inference_install_dir`文件夹);`CUDA_LIB_DIR`为cuda库文件地址,在docker中为`/usr/local/cuda/lib64`;`CUDNN_LIB_DIR`为cudnn库文件地址,在docker中为`/usr/lib/x86_64-linux-gnu/`。**注意:以上路径都写绝对路径,不要写相对路径。**
|
||||
|
||||
|
||||
* 编译完成之后,会在`build`文件夹下生成一个名为`ocr_system`的可执行文件。
|
||||
* 编译完成之后,会在`build`文件夹下生成一个名为`ppocr`的可执行文件。
|
||||
|
||||
|
||||
### 运行demo
|
||||
* 执行以下命令,完成对一幅图像的OCR识别与检测。
|
||||
|
||||
运行方式:
|
||||
```shell
|
||||
sh tools/run.sh
|
||||
./build/ppocr <mode> [--param1] [--param2] [...]
|
||||
```
|
||||
其中,`mode`为必选参数,表示选择的功能,取值范围['det', 'rec', 'system'],分别表示调用检测、识别、检测识别串联(包括方向分类器)。具体命令如下:
|
||||
|
||||
##### 1. 只调用检测:
|
||||
```shell
|
||||
./build/ppocr det \
|
||||
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
|
||||
--image_dir=../../doc/imgs/12.jpg
|
||||
```
|
||||
##### 2. 只调用识别:
|
||||
```shell
|
||||
./build/ppocr rec \
|
||||
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
|
||||
--image_dir=../../doc/imgs_words/ch/
|
||||
```
|
||||
##### 3. 调用串联:
|
||||
```shell
|
||||
# 不使用方向分类器
|
||||
./build/ppocr system \
|
||||
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
|
||||
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
|
||||
--image_dir=../../doc/imgs/12.jpg
|
||||
# 使用方向分类器
|
||||
./build/ppocr system \
|
||||
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
|
||||
--use_angle_cls=true \
|
||||
--cls_model_dir=inference/ch_ppocr_mobile_v2.0_cls_infer \
|
||||
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
|
||||
--image_dir=../../doc/imgs/12.jpg
|
||||
```
|
||||
|
||||
* 若需要使用方向分类器,则需要将`tools/config.txt`中的`use_angle_cls`参数修改为1,表示开启方向分类器的预测。
|
||||
* 更多地,tools/config.txt中的参数及解释如下。
|
||||
更多参数如下:
|
||||
|
||||
```
|
||||
use_gpu 0 # 是否使用GPU,1表示使用,0表示不使用
|
||||
gpu_id 0 # GPU id,使用GPU时有效
|
||||
gpu_mem 4000 # 申请的GPU内存
|
||||
cpu_math_library_num_threads 10 # CPU预测时的线程数,在机器核数充足的情况下,该值越大,预测速度越快
|
||||
use_mkldnn 1 # 是否使用mkldnn库
|
||||
- 通用参数
|
||||
|
||||
# det config
|
||||
max_side_len 960 # 输入图像长宽大于960时,等比例缩放图像,使得图像最长边为960
|
||||
det_db_thresh 0.3 # 用于过滤DB预测的二值化图像,设置为0.-0.3对结果影响不明显
|
||||
det_db_box_thresh 0.5 # DB后处理过滤box的阈值,如果检测存在漏框情况,可酌情减小
|
||||
det_db_unclip_ratio 1.6 # 表示文本框的紧致程度,越小则文本框更靠近文本
|
||||
use_polygon_score 1 # 是否使用多边形框计算bbox score,0表示使用矩形框计算。矩形框计算速度更快,多边形框对弯曲文本区域计算更准确。
|
||||
det_model_dir ./inference/det_db # 检测模型inference model地址
|
||||
|参数名称|类型|默认参数|意义|
|
||||
| --- | --- | --- | --- |
|
||||
|use_gpu|bool|false|是否使用GPU|
|
||||
|gpu_id|int|0|GPU id,使用GPU时有效|
|
||||
|gpu_mem|int|4000|申请的GPU内存|
|
||||
|cpu_math_library_num_threads|int|10|CPU预测时的线程数,在机器核数充足的情况下,该值越大,预测速度越快|
|
||||
|use_mkldnn|bool|true|是否使用mkldnn库|
|
||||
|
||||
# cls config
|
||||
use_angle_cls 0 # 是否使用方向分类器,0表示不使用,1表示使用
|
||||
cls_model_dir ./inference/cls # 方向分类器inference model地址
|
||||
cls_thresh 0.9 # 方向分类器的得分阈值
|
||||
- 检测模型相关
|
||||
|
||||
# rec config
|
||||
rec_model_dir ./inference/rec_crnn # 识别模型inference model地址
|
||||
char_list_file ../../ppocr/utils/ppocr_keys_v1.txt # 字典文件
|
||||
|参数名称|类型|默认参数|意义|
|
||||
| --- | --- | --- | --- |
|
||||
|det_model_dir|string|-|检测模型inference model地址|
|
||||
|max_side_len|int|960|输入图像长宽大于960时,等比例缩放图像,使得图像最长边为960|
|
||||
|det_db_thresh|float|0.3|用于过滤DB预测的二值化图像,设置为0.-0.3对结果影响不明显|
|
||||
|det_db_box_thresh|float|0.5|DB后处理过滤box的阈值,如果检测存在漏框情况,可酌情减小|
|
||||
|det_db_unclip_ratio|float|1.6|表示文本框的紧致程度,越小则文本框更靠近文本|
|
||||
|use_polygon_score|bool|false|是否使用多边形框计算bbox score,false表示使用矩形框计算。矩形框计算速度更快,多边形框对弯曲文本区域计算更准确。|
|
||||
|visualize|bool|true|是否对结果进行可视化,为1时,会在当前文件夹下保存文件名为`ocr_vis.png`的预测结果。|
|
||||
|
||||
# show the detection results
|
||||
visualize 1 # 是否对结果进行可视化,为1时,会在当前文件夹下保存文件名为`ocr_vis.png`的预测结果。
|
||||
```
|
||||
- 方向分类器相关
|
||||
|
||||
* PaddleOCR也支持多语言的预测,更多支持的语言和模型可以参考[识别文档](../../doc/doc_ch/recognition.md)中的多语言字典与模型部分,如果希望进行多语言预测,只需将修改`tools/config.txt`中的`char_list_file`(字典文件路径)以及`rec_model_dir`(inference模型路径)字段即可。
|
||||
|参数名称|类型|默认参数|意义|
|
||||
| --- | --- | --- | --- |
|
||||
|use_angle_cls|bool|false|是否使用方向分类器|
|
||||
|cls_model_dir|string|-|方向分类器inference model地址|
|
||||
|cls_thresh|float|0.9|方向分类器的得分阈值|
|
||||
|
||||
- 识别模型相关
|
||||
|
||||
|参数名称|类型|默认参数|意义|
|
||||
| --- | --- | --- | --- |
|
||||
|rec_model_dir|string|-|识别模型inference model地址|
|
||||
|char_list_file|string|../../ppocr/utils/ppocr_keys_v1.txt|字典文件|
|
||||
|
||||
|
||||
* PaddleOCR也支持多语言的预测,更多支持的语言和模型可以参考[识别文档](../../doc/doc_ch/recognition.md)中的多语言字典与模型部分,如果希望进行多语言预测,只需将修改`char_list_file`(字典文件路径)以及`rec_model_dir`(inference模型路径)字段即可。
|
||||
|
||||
最终屏幕上会输出检测结果如下。
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ PaddleOCR model deployment.
|
|||
* First of all, you need to download the source code compiled package in the Linux environment from the opencv official website. Taking opencv3.4.7 as an example, the download command is as follows.
|
||||
|
||||
```
|
||||
cd deploy/cpp_infer
|
||||
wget https://github.com/opencv/opencv/archive/3.4.7.tar.gz
|
||||
tar -xf 3.4.7.tar.gz
|
||||
```
|
||||
|
@ -161,30 +162,13 @@ inference/
|
|||
sh tools/build.sh
|
||||
```
|
||||
|
||||
Specifically, the content in `tools/build.sh` is as follows.
|
||||
Specifically, you should modify the paths in `tools/build.sh`. The related content is as follows.
|
||||
|
||||
```shell
|
||||
OPENCV_DIR=your_opencv_dir
|
||||
LIB_DIR=your_paddle_inference_dir
|
||||
CUDA_LIB_DIR=your_cuda_lib_dir
|
||||
CUDNN_LIB_DIR=your_cudnn_lib_dir
|
||||
|
||||
BUILD_DIR=build
|
||||
rm -rf ${BUILD_DIR}
|
||||
mkdir ${BUILD_DIR}
|
||||
cd ${BUILD_DIR}
|
||||
cmake .. \
|
||||
-DPADDLE_LIB=${LIB_DIR} \
|
||||
-DWITH_MKL=ON \
|
||||
-DDEMO_NAME=ocr_system \
|
||||
-DWITH_GPU=OFF \
|
||||
-DWITH_STATIC_LIB=OFF \
|
||||
-DUSE_TENSORRT=OFF \
|
||||
-DOPENCV_DIR=${OPENCV_DIR} \
|
||||
-DCUDNN_LIB=${CUDNN_LIB_DIR} \
|
||||
-DCUDA_LIB=${CUDA_LIB_DIR} \
|
||||
|
||||
make -j
|
||||
```
|
||||
|
||||
`OPENCV_DIR` is the opencv installation path; `LIB_DIR` is the download (`paddle_inference` folder)
|
||||
|
@ -192,48 +176,84 @@ or the generated Paddle inference library path (`build/paddle_inference_install_
|
|||
`CUDA_LIB_DIR` is the cuda library file path, in docker; it is `/usr/local/cuda/lib64`; `CUDNN_LIB_DIR` is the cudnn library file path, in docker it is `/usr/lib/x86_64-linux-gnu/`.
|
||||
|
||||
|
||||
* After the compilation is completed, an executable file named `ocr_system` will be generated in the `build` folder.
|
||||
* After the compilation is completed, an executable file named `ppocr` will be generated in the `build` folder.
|
||||
|
||||
|
||||
### Run the demo
|
||||
* Execute the following command to complete the OCR recognition and detection of an image.
|
||||
|
||||
Execute the built executable file:
|
||||
```shell
|
||||
sh tools/run.sh
|
||||
./build/ppocr <mode> [--param1] [--param2] [...]
|
||||
```
|
||||
Here, `mode` is a required parameter,and the value range is ['det', 'rec', 'system'], representing using detection only, using recognition only and using the end-to-end system respectively. Specifically,
|
||||
|
||||
##### 1. run det demo:
|
||||
```shell
|
||||
./build/ppocr det \
|
||||
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
|
||||
--image_dir=../../doc/imgs/12.jpg
|
||||
```
|
||||
##### 2. run rec demo:
|
||||
```shell
|
||||
./build/ppocr rec \
|
||||
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
|
||||
--image_dir=../../doc/imgs_words/ch/
|
||||
```
|
||||
##### 3. run system demo:
|
||||
```shell
|
||||
# without text direction classifier
|
||||
./build/ppocr system \
|
||||
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
|
||||
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
|
||||
--image_dir=../../doc/imgs/12.jpg
|
||||
# with text direction classifier
|
||||
./build/ppocr system \
|
||||
--det_model_dir=inference/ch_ppocr_mobile_v2.0_det_infer \
|
||||
--use_angle_cls=true \
|
||||
--cls_model_dir=inference/ch_ppocr_mobile_v2.0_cls_infer \
|
||||
--rec_model_dir=inference/ch_ppocr_mobile_v2.0_rec_infer \
|
||||
--image_dir=../../doc/imgs/12.jpg
|
||||
```
|
||||
|
||||
* If you want to orientation classifier to correct the detected boxes, you can set `use_angle_cls` in the file `tools/config.txt` as 1 to enable the function.
|
||||
* What's more, Parameters and their meanings in `tools/config.txt` are as follows.
|
||||
More parameters are as follows,
|
||||
|
||||
- common parameters
|
||||
|
||||
```
|
||||
use_gpu 0 # Whether to use GPU, 0 means not to use, 1 means to use
|
||||
gpu_id 0 # GPU id when use_gpu is 1
|
||||
gpu_mem 4000 # GPU memory requested
|
||||
cpu_math_library_num_threads 10 # Number of threads when using CPU inference. When machine cores is enough, the large the value, the faster the inference speed
|
||||
use_mkldnn 1 # Whether to use mkdlnn library
|
||||
|parameter|data type|default|meaning|
|
||||
| --- | --- | --- | --- |
|
||||
|use_gpu|bool|false|Whether to use GPU|
|
||||
|gpu_id|int|0|GPU id when use_gpu is true|
|
||||
|gpu_mem|int|4000|GPU memory requested|
|
||||
|cpu_math_library_num_threads|int|10|Number of threads when using CPU inference. When machine cores is enough, the large the value, the faster the inference speed|
|
||||
|use_mkldnn|bool|true|Whether to use mkdlnn library|
|
||||
|
||||
max_side_len 960 # Limit the maximum image height and width to 960
|
||||
det_db_thresh 0.3 # Used to filter the binarized image of DB prediction, setting 0.-0.3 has no obvious effect on the result
|
||||
det_db_box_thresh 0.5 # DDB post-processing filter box threshold, if there is a missing box detected, it can be reduced as appropriate
|
||||
det_db_unclip_ratio 1.6 # Indicates the compactness of the text box, the smaller the value, the closer the text box to the text
|
||||
use_polygon_score 1 # Whether to use polygon box to calculate bbox score, 0 means to use rectangle box to calculate. Use rectangular box to calculate faster, and polygonal box more accurate for curved text area.
|
||||
det_model_dir ./inference/det_db # Address of detection inference model
|
||||
- detection related parameters
|
||||
|
||||
# cls config
|
||||
use_angle_cls 0 # Whether to use the direction classifier, 0 means not to use, 1 means to use
|
||||
cls_model_dir ./inference/cls # Address of direction classifier inference model
|
||||
cls_thresh 0.9 # Score threshold of the direction classifier
|
||||
|parameter|data type|default|meaning|
|
||||
| --- | --- | --- | --- |
|
||||
|det_model_dir|string|-|Address of detection inference model|
|
||||
|max_side_len|int|960|Limit the maximum image height and width to 960|
|
||||
|det_db_thresh|float|0.3|Used to filter the binarized image of DB prediction, setting 0.-0.3 has no obvious effect on the result|
|
||||
|det_db_box_thresh|float|0.5|DB post-processing filter box threshold, if there is a missing box detected, it can be reduced as appropriate|
|
||||
|det_db_unclip_ratio|float|1.6|Indicates the compactness of the text box, the smaller the value, the closer the text box to the text|
|
||||
|use_polygon_score|bool|false|Whether to use polygon box to calculate bbox score, false means to use rectangle box to calculate. Use rectangular box to calculate faster, and polygonal box more accurate for curved text area.|
|
||||
|visualize|bool|true|Whether to visualize the results,when it is set as true, The prediction result will be save in the image file `./ocr_vis.png`.|
|
||||
|
||||
# rec config
|
||||
rec_model_dir ./inference/rec_crnn # Address of recognition inference model
|
||||
char_list_file ../../ppocr/utils/ppocr_keys_v1.txt # dictionary file
|
||||
- classifier related parameters
|
||||
|
||||
# show the detection results
|
||||
visualize 1 # Whether to visualize the results,when it is set as 1, The prediction result will be save in the image file `./ocr_vis.png`.
|
||||
```
|
||||
|parameter|data type|default|meaning|
|
||||
| --- | --- | --- | --- |
|
||||
|use_angle_cls|bool|false|Whether to use the direction classifier|
|
||||
|cls_model_dir|string|-|Address of direction classifier inference model|
|
||||
|cls_thresh|float|0.9|Score threshold of the direction classifier|
|
||||
|
||||
* Multi-language inference is also supported in PaddleOCR, you can refer to [recognition tutorial](../../doc/doc_en/recognition_en.md) for more supported languages and models in PaddleOCR. Specifically, if you want to infer using multi-language models, you just need to modify values of `char_list_file` and `rec_model_dir` in file `tools/config.txt`.
|
||||
- recogniton related parameters
|
||||
|
||||
|parameter|data type|default|meaning|
|
||||
| --- | --- | --- | --- |
|
||||
|rec_model_dir|string|-|Address of recognition inference model|
|
||||
|char_list_file|string|../../ppocr/utils/ppocr_keys_v1.txt|dictionary file|
|
||||
|
||||
* Multi-language inference is also supported in PaddleOCR, you can refer to [recognition tutorial](../../doc/doc_en/recognition_en.md) for more supported languages and models in PaddleOCR. Specifically, if you want to infer using multi-language models, you just need to modify values of `char_list_file` and `rec_model_dir`.
|
||||
|
||||
|
||||
The detection results will be shown on the screen, which is as follows.
|
||||
|
|
|
@ -1,64 +0,0 @@
|
|||
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <include/config.h>
|
||||
|
||||
namespace PaddleOCR {
|
||||
|
||||
std::vector<std::string> OCRConfig::split(const std::string &str,
|
||||
const std::string &delim) {
|
||||
std::vector<std::string> res;
|
||||
if ("" == str)
|
||||
return res;
|
||||
char *strs = new char[str.length() + 1];
|
||||
std::strcpy(strs, str.c_str());
|
||||
|
||||
char *d = new char[delim.length() + 1];
|
||||
std::strcpy(d, delim.c_str());
|
||||
|
||||
char *p = std::strtok(strs, d);
|
||||
while (p) {
|
||||
std::string s = p;
|
||||
res.push_back(s);
|
||||
p = std::strtok(NULL, d);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
std::map<std::string, std::string>
|
||||
OCRConfig::LoadConfig(const std::string &config_path) {
|
||||
auto config = Utility::ReadDict(config_path);
|
||||
|
||||
std::map<std::string, std::string> dict;
|
||||
for (int i = 0; i < config.size(); i++) {
|
||||
// pass for empty line or comment
|
||||
if (config[i].size() <= 1 || config[i][0] == '#') {
|
||||
continue;
|
||||
}
|
||||
std::vector<std::string> res = split(config[i], " ");
|
||||
dict[res[0]] = res[1];
|
||||
}
|
||||
return dict;
|
||||
}
|
||||
|
||||
void OCRConfig::PrintConfigInfo() {
|
||||
std::cout << "=======Paddle OCR inference config======" << std::endl;
|
||||
for (auto iter = config_map_.begin(); iter != config_map_.end(); iter++) {
|
||||
std::cout << iter->first << " : " << iter->second << std::endl;
|
||||
}
|
||||
std::cout << "=======End of Paddle OCR inference config======" << std::endl;
|
||||
}
|
||||
|
||||
} // namespace PaddleOCR
|
|
@ -28,67 +28,198 @@
|
|||
#include <numeric>
|
||||
|
||||
#include <glog/logging.h>
|
||||
#include <include/config.h>
|
||||
#include <include/ocr_det.h>
|
||||
#include <include/ocr_cls.h>
|
||||
#include <include/ocr_rec.h>
|
||||
#include <include/utility.h>
|
||||
#include <sys/stat.h>
|
||||
|
||||
#include <gflags/gflags.h>
|
||||
|
||||
DEFINE_bool(use_gpu, false, "Infering with GPU or CPU.");
|
||||
DEFINE_int32(gpu_id, 0, "Device id of GPU to execute.");
|
||||
DEFINE_int32(gpu_mem, 4000, "GPU id when infering with GPU.");
|
||||
DEFINE_int32(cpu_math_library_num_threads, 10, "Num of threads with CPU.");
|
||||
DEFINE_bool(use_mkldnn, false, "Whether use mkldnn with CPU.");
|
||||
DEFINE_bool(use_tensorrt, false, "Whether use tensorrt.");
|
||||
DEFINE_string(precision, "fp32", "Precision be one of fp32/fp16/int8");
|
||||
DEFINE_bool(benchmark, true, "Whether use benchmark.");
|
||||
DEFINE_string(save_log_path, "./log_output/", "Save benchmark log path.");
|
||||
// detection related
|
||||
DEFINE_string(image_dir, "", "Dir of input image.");
|
||||
DEFINE_string(det_model_dir, "", "Path of det inference model.");
|
||||
DEFINE_int32(max_side_len, 960, "max_side_len of input image.");
|
||||
DEFINE_double(det_db_thresh, 0.3, "Threshold of det_db_thresh.");
|
||||
DEFINE_double(det_db_box_thresh, 0.5, "Threshold of det_db_box_thresh.");
|
||||
DEFINE_double(det_db_unclip_ratio, 1.6, "Threshold of det_db_unclip_ratio.");
|
||||
DEFINE_bool(use_polygon_score, false, "Whether use polygon score.");
|
||||
DEFINE_bool(visualize, true, "Whether show the detection results.");
|
||||
// classification related
|
||||
DEFINE_bool(use_angle_cls, false, "Whether use use_angle_cls.");
|
||||
DEFINE_string(cls_model_dir, "", "Path of cls inference model.");
|
||||
DEFINE_double(cls_thresh, 0.9, "Threshold of cls_thresh.");
|
||||
// recognition related
|
||||
DEFINE_string(rec_model_dir, "", "Path of rec inference model.");
|
||||
DEFINE_string(char_list_file, "../../ppocr/utils/ppocr_keys_v1.txt", "Path of dictionary.");
|
||||
|
||||
|
||||
using namespace std;
|
||||
using namespace cv;
|
||||
using namespace PaddleOCR;
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
if (argc < 3) {
|
||||
std::cerr << "[ERROR] usage: " << argv[0]
|
||||
<< " configure_filepath image_path\n";
|
||||
exit(1);
|
||||
}
|
||||
|
||||
OCRConfig config(argv[1]);
|
||||
void PrintBenchmarkLog(std::string model_name,
|
||||
int batch_size,
|
||||
std::string input_shape,
|
||||
std::vector<double> time_info,
|
||||
int img_num){
|
||||
LOG(INFO) << "----------------------- Config info -----------------------";
|
||||
LOG(INFO) << "runtime_device: " << (FLAGS_use_gpu ? "gpu" : "cpu");
|
||||
LOG(INFO) << "ir_optim: " << "True";
|
||||
LOG(INFO) << "enable_memory_optim: " << "True";
|
||||
LOG(INFO) << "enable_tensorrt: " << FLAGS_use_tensorrt;
|
||||
LOG(INFO) << "enable_mkldnn: " << (FLAGS_use_mkldnn ? "True" : "False");
|
||||
LOG(INFO) << "cpu_math_library_num_threads: " << FLAGS_cpu_math_library_num_threads;
|
||||
LOG(INFO) << "----------------------- Data info -----------------------";
|
||||
LOG(INFO) << "batch_size: " << batch_size;
|
||||
LOG(INFO) << "input_shape: " << input_shape;
|
||||
LOG(INFO) << "data_num: " << img_num;
|
||||
LOG(INFO) << "----------------------- Model info -----------------------";
|
||||
LOG(INFO) << "model_name: " << model_name;
|
||||
LOG(INFO) << "precision: " << FLAGS_precision;
|
||||
LOG(INFO) << "----------------------- Perf info ------------------------";
|
||||
LOG(INFO) << "Total time spent(ms): "
|
||||
<< std::accumulate(time_info.begin(), time_info.end(), 0);
|
||||
LOG(INFO) << "preprocess_time(ms): " << time_info[0] / img_num
|
||||
<< ", inference_time(ms): " << time_info[1] / img_num
|
||||
<< ", postprocess_time(ms): " << time_info[2] / img_num;
|
||||
}
|
||||
|
||||
config.PrintConfigInfo();
|
||||
|
||||
std::string img_path(argv[2]);
|
||||
std::vector<std::string> all_img_names;
|
||||
Utility::GetAllFiles((char *)img_path.c_str(), all_img_names);
|
||||
static bool PathExists(const std::string& path){
|
||||
#ifdef _WIN32
|
||||
struct _stat buffer;
|
||||
return (_stat(path.c_str(), &buffer) == 0);
|
||||
#else
|
||||
struct stat buffer;
|
||||
return (stat(path.c_str(), &buffer) == 0);
|
||||
#endif // !_WIN32
|
||||
}
|
||||
|
||||
DBDetector det(config.det_model_dir, config.use_gpu, config.gpu_id,
|
||||
config.gpu_mem, config.cpu_math_library_num_threads,
|
||||
config.use_mkldnn, config.max_side_len, config.det_db_thresh,
|
||||
config.det_db_box_thresh, config.det_db_unclip_ratio,
|
||||
config.use_polygon_score, config.visualize,
|
||||
config.use_tensorrt, config.use_fp16);
|
||||
|
||||
Classifier *cls = nullptr;
|
||||
if (config.use_angle_cls == true) {
|
||||
cls = new Classifier(config.cls_model_dir, config.use_gpu, config.gpu_id,
|
||||
config.gpu_mem, config.cpu_math_library_num_threads,
|
||||
config.use_mkldnn, config.cls_thresh,
|
||||
config.use_tensorrt, config.use_fp16);
|
||||
}
|
||||
int main_det(std::vector<cv::String> cv_all_img_names) {
|
||||
std::vector<double> time_info = {0, 0, 0};
|
||||
DBDetector det(FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
|
||||
FLAGS_gpu_mem, FLAGS_cpu_math_library_num_threads,
|
||||
FLAGS_use_mkldnn, FLAGS_max_side_len, FLAGS_det_db_thresh,
|
||||
FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio,
|
||||
FLAGS_use_polygon_score, FLAGS_visualize,
|
||||
FLAGS_use_tensorrt, FLAGS_precision);
|
||||
|
||||
CRNNRecognizer rec(config.rec_model_dir, config.use_gpu, config.gpu_id,
|
||||
config.gpu_mem, config.cpu_math_library_num_threads,
|
||||
config.use_mkldnn, config.char_list_file,
|
||||
config.use_tensorrt, config.use_fp16);
|
||||
for (int i = 0; i < cv_all_img_names.size(); ++i) {
|
||||
LOG(INFO) << "The predict img: " << cv_all_img_names[i];
|
||||
|
||||
auto start = std::chrono::system_clock::now();
|
||||
|
||||
for (auto img_dir : all_img_names) {
|
||||
LOG(INFO) << "The predict img: " << img_dir;
|
||||
|
||||
cv::Mat srcimg = cv::imread(img_dir, cv::IMREAD_COLOR);
|
||||
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
|
||||
if (!srcimg.data) {
|
||||
std::cerr << "[ERROR] image read failed! image path: " << img_path
|
||||
<< "\n";
|
||||
std::cerr << "[ERROR] image read failed! image path: " << cv_all_img_names[i] << endl;
|
||||
exit(1);
|
||||
}
|
||||
std::vector<std::vector<std::vector<int>>> boxes;
|
||||
std::vector<double> det_times;
|
||||
|
||||
det.Run(srcimg, boxes);
|
||||
det.Run(srcimg, boxes, &det_times);
|
||||
|
||||
time_info[0] += det_times[0];
|
||||
time_info[1] += det_times[1];
|
||||
time_info[2] += det_times[2];
|
||||
}
|
||||
|
||||
if (FLAGS_benchmark) {
|
||||
PrintBenchmarkLog("det", 1, "dynamic", time_info, cv_all_img_names.size());
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
int main_rec(std::vector<cv::String> cv_all_img_names) {
|
||||
std::vector<double> time_info = {0, 0, 0};
|
||||
CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
|
||||
FLAGS_gpu_mem, FLAGS_cpu_math_library_num_threads,
|
||||
FLAGS_use_mkldnn, FLAGS_char_list_file,
|
||||
FLAGS_use_tensorrt, FLAGS_precision);
|
||||
|
||||
for (int i = 0; i < cv_all_img_names.size(); ++i) {
|
||||
LOG(INFO) << "The predict img: " << cv_all_img_names[i];
|
||||
|
||||
cv::Mat srcimg = cv::imread(cv_all_img_names[i], cv::IMREAD_COLOR);
|
||||
if (!srcimg.data) {
|
||||
std::cerr << "[ERROR] image read failed! image path: " << cv_all_img_names[i] << endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::vector<double> rec_times;
|
||||
rec.Run(srcimg, &rec_times);
|
||||
|
||||
time_info[0] += rec_times[0];
|
||||
time_info[1] += rec_times[1];
|
||||
time_info[2] += rec_times[2];
|
||||
}
|
||||
|
||||
if (FLAGS_benchmark) {
|
||||
PrintBenchmarkLog("rec", 1, "dynamic", time_info, cv_all_img_names.size());
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
int main_system(std::vector<cv::String> cv_all_img_names) {
|
||||
DBDetector det(FLAGS_det_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
|
||||
FLAGS_gpu_mem, FLAGS_cpu_math_library_num_threads,
|
||||
FLAGS_use_mkldnn, FLAGS_max_side_len, FLAGS_det_db_thresh,
|
||||
FLAGS_det_db_box_thresh, FLAGS_det_db_unclip_ratio,
|
||||
FLAGS_use_polygon_score, FLAGS_visualize,
|
||||
FLAGS_use_tensorrt, FLAGS_precision);
|
||||
|
||||
Classifier *cls = nullptr;
|
||||
if (FLAGS_use_angle_cls) {
|
||||
cls = new Classifier(FLAGS_cls_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
|
||||
FLAGS_gpu_mem, FLAGS_cpu_math_library_num_threads,
|
||||
FLAGS_use_mkldnn, FLAGS_cls_thresh,
|
||||
FLAGS_use_tensorrt, FLAGS_precision);
|
||||
}
|
||||
|
||||
CRNNRecognizer rec(FLAGS_rec_model_dir, FLAGS_use_gpu, FLAGS_gpu_id,
|
||||
FLAGS_gpu_mem, FLAGS_cpu_math_library_num_threads,
|
||||
FLAGS_use_mkldnn, FLAGS_char_list_file,
|
||||
FLAGS_use_tensorrt, FLAGS_precision);
|
||||
|
||||
auto start = std::chrono::system_clock::now();
|
||||
|
||||
for (int i = 0; i < cv_all_img_names.size(); ++i) {
|
||||
LOG(INFO) << "The predict img: " << cv_all_img_names[i];
|
||||
|
||||
cv::Mat srcimg = cv::imread(FLAGS_image_dir, cv::IMREAD_COLOR);
|
||||
if (!srcimg.data) {
|
||||
std::cerr << "[ERROR] image read failed! image path: " << cv_all_img_names[i] << endl;
|
||||
exit(1);
|
||||
}
|
||||
std::vector<std::vector<std::vector<int>>> boxes;
|
||||
std::vector<double> det_times;
|
||||
std::vector<double> rec_times;
|
||||
|
||||
det.Run(srcimg, boxes, &det_times);
|
||||
|
||||
cv::Mat crop_img;
|
||||
for (int j = 0; j < boxes.size(); j++) {
|
||||
crop_img = Utility::GetRotateCropImage(srcimg, boxes[j]);
|
||||
|
||||
if (cls != nullptr) {
|
||||
crop_img = cls->Run(crop_img);
|
||||
}
|
||||
rec.Run(crop_img, &rec_times);
|
||||
}
|
||||
|
||||
rec.Run(boxes, srcimg, cls);
|
||||
auto end = std::chrono::system_clock::now();
|
||||
auto duration =
|
||||
std::chrono::duration_cast<std::chrono::microseconds>(end - start);
|
||||
|
@ -101,3 +232,72 @@ int main(int argc, char **argv) {
|
|||
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
void check_params(char* mode) {
|
||||
if (strcmp(mode, "det")==0) {
|
||||
if (FLAGS_det_model_dir.empty() || FLAGS_image_dir.empty()) {
|
||||
std::cout << "Usage[det]: ./ppocr --det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ "
|
||||
<< "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
if (strcmp(mode, "rec")==0) {
|
||||
if (FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) {
|
||||
std::cout << "Usage[rec]: ./ppocr --rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
|
||||
<< "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
if (strcmp(mode, "system")==0) {
|
||||
if ((FLAGS_det_model_dir.empty() || FLAGS_rec_model_dir.empty() || FLAGS_image_dir.empty()) ||
|
||||
(FLAGS_use_angle_cls && FLAGS_cls_model_dir.empty())) {
|
||||
std::cout << "Usage[system without angle cls]: ./ppocr --det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ "
|
||||
<< "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
|
||||
<< "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
|
||||
std::cout << "Usage[system with angle cls]: ./ppocr --det_model_dir=/PATH/TO/DET_INFERENCE_MODEL/ "
|
||||
<< "--use_angle_cls=true "
|
||||
<< "--cls_model_dir=/PATH/TO/CLS_INFERENCE_MODEL/ "
|
||||
<< "--rec_model_dir=/PATH/TO/REC_INFERENCE_MODEL/ "
|
||||
<< "--image_dir=/PATH/TO/INPUT/IMAGE/" << std::endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
if (FLAGS_precision != "fp32" && FLAGS_precision != "fp16" && FLAGS_precision != "int8") {
|
||||
cout << "precison should be 'fp32'(default), 'fp16' or 'int8'. " << endl;
|
||||
exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
if (argc<=1 || (strcmp(argv[1], "det")!=0 && strcmp(argv[1], "rec")!=0 && strcmp(argv[1], "system")!=0)) {
|
||||
std::cout << "Please choose one mode of [det, rec, system] !" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
std::cout << "mode: " << argv[1] << endl;
|
||||
|
||||
// Parsing command-line
|
||||
google::ParseCommandLineFlags(&argc, &argv, true);
|
||||
check_params(argv[1]);
|
||||
|
||||
if (!PathExists(FLAGS_image_dir)) {
|
||||
std::cerr << "[ERROR] image path not exist! image_dir: " << FLAGS_image_dir << endl;
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::vector<cv::String> cv_all_img_names;
|
||||
cv::glob(FLAGS_image_dir, cv_all_img_names);
|
||||
std::cout << "total images num: " << cv_all_img_names.size() << endl;
|
||||
|
||||
if (strcmp(argv[1], "det")==0) {
|
||||
return main_det(cv_all_img_names);
|
||||
}
|
||||
if (strcmp(argv[1], "rec")==0) {
|
||||
return main_rec(cv_all_img_names);
|
||||
}
|
||||
if (strcmp(argv[1], "system")==0) {
|
||||
return main_system(cv_all_img_names);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -77,10 +77,16 @@ void Classifier::LoadModel(const std::string &model_dir) {
|
|||
if (this->use_gpu_) {
|
||||
config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
|
||||
if (this->use_tensorrt_) {
|
||||
auto precision = paddle_infer::Config::Precision::kFloat32;
|
||||
if (this->precision_ == "fp16") {
|
||||
precision = paddle_infer::Config::Precision::kHalf;
|
||||
}
|
||||
if (this->precision_ == "int8") {
|
||||
precision = paddle_infer::Config::Precision::kInt8;
|
||||
}
|
||||
config.EnableTensorRtEngine(
|
||||
1 << 20, 10, 3,
|
||||
this->use_fp16_ ? paddle_infer::Config::Precision::kHalf
|
||||
: paddle_infer::Config::Precision::kFloat32,
|
||||
precision,
|
||||
false, false);
|
||||
}
|
||||
} else {
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
#include <include/ocr_det.h>
|
||||
|
||||
|
||||
namespace PaddleOCR {
|
||||
|
||||
void DBDetector::LoadModel(const std::string &model_dir) {
|
||||
|
@ -25,10 +26,16 @@ void DBDetector::LoadModel(const std::string &model_dir) {
|
|||
if (this->use_gpu_) {
|
||||
config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
|
||||
if (this->use_tensorrt_) {
|
||||
auto precision = paddle_infer::Config::Precision::kFloat32;
|
||||
if (this->precision_ == "fp16") {
|
||||
precision = paddle_infer::Config::Precision::kHalf;
|
||||
}
|
||||
if (this->precision_ == "int8") {
|
||||
precision = paddle_infer::Config::Precision::kInt8;
|
||||
}
|
||||
config.EnableTensorRtEngine(
|
||||
1 << 20, 10, 3,
|
||||
this->use_fp16_ ? paddle_infer::Config::Precision::kHalf
|
||||
: paddle_infer::Config::Precision::kFloat32,
|
||||
precision,
|
||||
false, false);
|
||||
std::map<std::string, std::vector<int>> min_input_shape = {
|
||||
{"x", {1, 3, 50, 50}},
|
||||
|
@ -90,13 +97,16 @@ void DBDetector::LoadModel(const std::string &model_dir) {
|
|||
}
|
||||
|
||||
void DBDetector::Run(cv::Mat &img,
|
||||
std::vector<std::vector<std::vector<int>>> &boxes) {
|
||||
std::vector<std::vector<std::vector<int>>> &boxes,
|
||||
std::vector<double> *times) {
|
||||
float ratio_h{};
|
||||
float ratio_w{};
|
||||
|
||||
cv::Mat srcimg;
|
||||
cv::Mat resize_img;
|
||||
img.copyTo(srcimg);
|
||||
|
||||
auto preprocess_start = std::chrono::steady_clock::now();
|
||||
this->resize_op_.Run(img, resize_img, this->max_side_len_, ratio_h, ratio_w,
|
||||
this->use_tensorrt_);
|
||||
|
||||
|
@ -105,12 +115,15 @@ void DBDetector::Run(cv::Mat &img,
|
|||
|
||||
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
|
||||
this->permute_op_.Run(&resize_img, input.data());
|
||||
auto preprocess_end = std::chrono::steady_clock::now();
|
||||
|
||||
// Inference.
|
||||
auto input_names = this->predictor_->GetInputNames();
|
||||
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
|
||||
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
|
||||
auto inference_start = std::chrono::steady_clock::now();
|
||||
input_t->CopyFromCpu(input.data());
|
||||
|
||||
this->predictor_->Run();
|
||||
|
||||
std::vector<float> out_data;
|
||||
|
@ -122,7 +135,9 @@ void DBDetector::Run(cv::Mat &img,
|
|||
|
||||
out_data.resize(out_num);
|
||||
output_t->CopyToCpu(out_data.data());
|
||||
auto inference_end = std::chrono::steady_clock::now();
|
||||
|
||||
auto postprocess_start = std::chrono::steady_clock::now();
|
||||
int n2 = output_shape[2];
|
||||
int n3 = output_shape[3];
|
||||
int n = n2 * n3;
|
||||
|
@ -150,6 +165,15 @@ void DBDetector::Run(cv::Mat &img,
|
|||
this->det_db_unclip_ratio_, this->use_polygon_score_);
|
||||
|
||||
boxes = post_processor_.FilterTagDetRes(boxes, ratio_h, ratio_w, srcimg);
|
||||
auto postprocess_end = std::chrono::steady_clock::now();
|
||||
std::cout << "Detected boxes num: " << boxes.size() << endl;
|
||||
|
||||
std::chrono::duration<float> preprocess_diff = preprocess_end - preprocess_start;
|
||||
times->push_back(double(preprocess_diff.count() * 1000));
|
||||
std::chrono::duration<float> inference_diff = inference_end - inference_start;
|
||||
times->push_back(double(inference_diff.count() * 1000));
|
||||
std::chrono::duration<float> postprocess_diff = postprocess_end - postprocess_start;
|
||||
times->push_back(double(postprocess_diff.count() * 1000));
|
||||
|
||||
//// visualization
|
||||
if (this->visualize_) {
|
||||
|
|
|
@ -16,25 +16,14 @@
|
|||
|
||||
namespace PaddleOCR {
|
||||
|
||||
void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
||||
cv::Mat &img, Classifier *cls) {
|
||||
void CRNNRecognizer::Run(cv::Mat &img, std::vector<double> *times) {
|
||||
cv::Mat srcimg;
|
||||
img.copyTo(srcimg);
|
||||
cv::Mat crop_img;
|
||||
cv::Mat resize_img;
|
||||
|
||||
std::cout << "The predicted text is :" << std::endl;
|
||||
int index = 0;
|
||||
for (int i = 0; i < boxes.size(); i++) {
|
||||
crop_img = GetRotateCropImage(srcimg, boxes[i]);
|
||||
|
||||
if (cls != nullptr) {
|
||||
crop_img = cls->Run(crop_img);
|
||||
}
|
||||
|
||||
float wh_ratio = float(crop_img.cols) / float(crop_img.rows);
|
||||
|
||||
this->resize_op_.Run(crop_img, resize_img, wh_ratio, this->use_tensorrt_);
|
||||
float wh_ratio = float(srcimg.cols) / float(srcimg.rows);
|
||||
auto preprocess_start = std::chrono::steady_clock::now();
|
||||
this->resize_op_.Run(srcimg, resize_img, wh_ratio, this->use_tensorrt_);
|
||||
|
||||
this->normalize_op_.Run(&resize_img, this->mean_, this->scale_,
|
||||
this->is_scale_);
|
||||
|
@ -42,11 +31,13 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
|||
std::vector<float> input(1 * 3 * resize_img.rows * resize_img.cols, 0.0f);
|
||||
|
||||
this->permute_op_.Run(&resize_img, input.data());
|
||||
auto preprocess_end = std::chrono::steady_clock::now();
|
||||
|
||||
// Inference.
|
||||
auto input_names = this->predictor_->GetInputNames();
|
||||
auto input_t = this->predictor_->GetInputHandle(input_names[0]);
|
||||
input_t->Reshape({1, 3, resize_img.rows, resize_img.cols});
|
||||
auto inference_start = std::chrono::steady_clock::now();
|
||||
input_t->CopyFromCpu(input.data());
|
||||
this->predictor_->Run();
|
||||
|
||||
|
@ -60,8 +51,10 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
|||
predict_batch.resize(out_num);
|
||||
|
||||
output_t->CopyToCpu(predict_batch.data());
|
||||
auto inference_end = std::chrono::steady_clock::now();
|
||||
|
||||
// ctc decode
|
||||
auto postprocess_start = std::chrono::steady_clock::now();
|
||||
std::vector<std::string> str_res;
|
||||
int argmax_idx;
|
||||
int last_index = 0;
|
||||
|
@ -84,12 +77,19 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
|||
}
|
||||
last_index = argmax_idx;
|
||||
}
|
||||
auto postprocess_end = std::chrono::steady_clock::now();
|
||||
score /= count;
|
||||
for (int i = 0; i < str_res.size(); i++) {
|
||||
std::cout << str_res[i];
|
||||
}
|
||||
std::cout << "\tscore: " << score << std::endl;
|
||||
}
|
||||
|
||||
std::chrono::duration<float> preprocess_diff = preprocess_end - preprocess_start;
|
||||
times->push_back(double(preprocess_diff.count() * 1000));
|
||||
std::chrono::duration<float> inference_diff = inference_end - inference_start;
|
||||
times->push_back(double(inference_diff.count() * 1000));
|
||||
std::chrono::duration<float> postprocess_diff = postprocess_end - postprocess_start;
|
||||
times->push_back(double(postprocess_diff.count() * 1000));
|
||||
}
|
||||
|
||||
void CRNNRecognizer::LoadModel(const std::string &model_dir) {
|
||||
|
@ -101,10 +101,16 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
|
|||
if (this->use_gpu_) {
|
||||
config.EnableUseGpu(this->gpu_mem_, this->gpu_id_);
|
||||
if (this->use_tensorrt_) {
|
||||
auto precision = paddle_infer::Config::Precision::kFloat32;
|
||||
if (this->precision_ == "fp16") {
|
||||
precision = paddle_infer::Config::Precision::kHalf;
|
||||
}
|
||||
if (this->precision_ == "int8") {
|
||||
precision = paddle_infer::Config::Precision::kInt8;
|
||||
}
|
||||
config.EnableTensorRtEngine(
|
||||
1 << 20, 10, 3,
|
||||
this->use_fp16_ ? paddle_infer::Config::Precision::kHalf
|
||||
: paddle_infer::Config::Precision::kFloat32,
|
||||
precision,
|
||||
false, false);
|
||||
std::map<std::string, std::vector<int>> min_input_shape = {
|
||||
{"x", {1, 3, 32, 10}}};
|
||||
|
@ -138,59 +144,4 @@ void CRNNRecognizer::LoadModel(const std::string &model_dir) {
|
|||
this->predictor_ = CreatePredictor(config);
|
||||
}
|
||||
|
||||
cv::Mat CRNNRecognizer::GetRotateCropImage(const cv::Mat &srcimage,
|
||||
std::vector<std::vector<int>> box) {
|
||||
cv::Mat image;
|
||||
srcimage.copyTo(image);
|
||||
std::vector<std::vector<int>> points = box;
|
||||
|
||||
int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]};
|
||||
int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]};
|
||||
int left = int(*std::min_element(x_collect, x_collect + 4));
|
||||
int right = int(*std::max_element(x_collect, x_collect + 4));
|
||||
int top = int(*std::min_element(y_collect, y_collect + 4));
|
||||
int bottom = int(*std::max_element(y_collect, y_collect + 4));
|
||||
|
||||
cv::Mat img_crop;
|
||||
image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop);
|
||||
|
||||
for (int i = 0; i < points.size(); i++) {
|
||||
points[i][0] -= left;
|
||||
points[i][1] -= top;
|
||||
}
|
||||
|
||||
int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) +
|
||||
pow(points[0][1] - points[1][1], 2)));
|
||||
int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) +
|
||||
pow(points[0][1] - points[3][1], 2)));
|
||||
|
||||
cv::Point2f pts_std[4];
|
||||
pts_std[0] = cv::Point2f(0., 0.);
|
||||
pts_std[1] = cv::Point2f(img_crop_width, 0.);
|
||||
pts_std[2] = cv::Point2f(img_crop_width, img_crop_height);
|
||||
pts_std[3] = cv::Point2f(0.f, img_crop_height);
|
||||
|
||||
cv::Point2f pointsf[4];
|
||||
pointsf[0] = cv::Point2f(points[0][0], points[0][1]);
|
||||
pointsf[1] = cv::Point2f(points[1][0], points[1][1]);
|
||||
pointsf[2] = cv::Point2f(points[2][0], points[2][1]);
|
||||
pointsf[3] = cv::Point2f(points[3][0], points[3][1]);
|
||||
|
||||
cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std);
|
||||
|
||||
cv::Mat dst_img;
|
||||
cv::warpPerspective(img_crop, dst_img, M,
|
||||
cv::Size(img_crop_width, img_crop_height),
|
||||
cv::BORDER_REPLICATE);
|
||||
|
||||
if (float(dst_img.rows) >= float(dst_img.cols) * 1.5) {
|
||||
cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth());
|
||||
cv::transpose(dst_img, srcCopy);
|
||||
cv::flip(srcCopy, srcCopy, 0);
|
||||
return srcCopy;
|
||||
} else {
|
||||
return dst_img;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace PaddleOCR
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
// limitations under the License.
|
||||
|
||||
#include <include/postprocess_op.h>
|
||||
#include <include/clipper.cpp>
|
||||
|
||||
namespace PaddleOCR {
|
||||
|
||||
|
|
|
@ -47,16 +47,13 @@ void Normalize::Run(cv::Mat *im, const std::vector<float> &mean,
|
|||
e /= 255.0;
|
||||
}
|
||||
(*im).convertTo(*im, CV_32FC3, e);
|
||||
for (int h = 0; h < im->rows; h++) {
|
||||
for (int w = 0; w < im->cols; w++) {
|
||||
im->at<cv::Vec3f>(h, w)[0] =
|
||||
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) * scale[0];
|
||||
im->at<cv::Vec3f>(h, w)[1] =
|
||||
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) * scale[1];
|
||||
im->at<cv::Vec3f>(h, w)[2] =
|
||||
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) * scale[2];
|
||||
}
|
||||
std::vector<cv::Mat> bgr_channels(3);
|
||||
cv::split(*im, bgr_channels);
|
||||
for (auto i = 0; i < bgr_channels.size(); i++) {
|
||||
bgr_channels[i].convertTo(bgr_channels[i], CV_32FC1, 1.0 * scale[i],
|
||||
(0.0 - mean[i]) * scale[i]);
|
||||
}
|
||||
cv::merge(bgr_channels, *im);
|
||||
}
|
||||
|
||||
void ResizeImgType0::Run(const cv::Mat &img, cv::Mat &resize_img,
|
||||
|
|
|
@ -92,4 +92,59 @@ void Utility::GetAllFiles(const char *dir_name,
|
|||
}
|
||||
}
|
||||
|
||||
cv::Mat Utility::GetRotateCropImage(const cv::Mat &srcimage,
|
||||
std::vector<std::vector<int>> box) {
|
||||
cv::Mat image;
|
||||
srcimage.copyTo(image);
|
||||
std::vector<std::vector<int>> points = box;
|
||||
|
||||
int x_collect[4] = {box[0][0], box[1][0], box[2][0], box[3][0]};
|
||||
int y_collect[4] = {box[0][1], box[1][1], box[2][1], box[3][1]};
|
||||
int left = int(*std::min_element(x_collect, x_collect + 4));
|
||||
int right = int(*std::max_element(x_collect, x_collect + 4));
|
||||
int top = int(*std::min_element(y_collect, y_collect + 4));
|
||||
int bottom = int(*std::max_element(y_collect, y_collect + 4));
|
||||
|
||||
cv::Mat img_crop;
|
||||
image(cv::Rect(left, top, right - left, bottom - top)).copyTo(img_crop);
|
||||
|
||||
for (int i = 0; i < points.size(); i++) {
|
||||
points[i][0] -= left;
|
||||
points[i][1] -= top;
|
||||
}
|
||||
|
||||
int img_crop_width = int(sqrt(pow(points[0][0] - points[1][0], 2) +
|
||||
pow(points[0][1] - points[1][1], 2)));
|
||||
int img_crop_height = int(sqrt(pow(points[0][0] - points[3][0], 2) +
|
||||
pow(points[0][1] - points[3][1], 2)));
|
||||
|
||||
cv::Point2f pts_std[4];
|
||||
pts_std[0] = cv::Point2f(0., 0.);
|
||||
pts_std[1] = cv::Point2f(img_crop_width, 0.);
|
||||
pts_std[2] = cv::Point2f(img_crop_width, img_crop_height);
|
||||
pts_std[3] = cv::Point2f(0.f, img_crop_height);
|
||||
|
||||
cv::Point2f pointsf[4];
|
||||
pointsf[0] = cv::Point2f(points[0][0], points[0][1]);
|
||||
pointsf[1] = cv::Point2f(points[1][0], points[1][1]);
|
||||
pointsf[2] = cv::Point2f(points[2][0], points[2][1]);
|
||||
pointsf[3] = cv::Point2f(points[3][0], points[3][1]);
|
||||
|
||||
cv::Mat M = cv::getPerspectiveTransform(pointsf, pts_std);
|
||||
|
||||
cv::Mat dst_img;
|
||||
cv::warpPerspective(img_crop, dst_img, M,
|
||||
cv::Size(img_crop_width, img_crop_height),
|
||||
cv::BORDER_REPLICATE);
|
||||
|
||||
if (float(dst_img.rows) >= float(dst_img.cols) * 1.5) {
|
||||
cv::Mat srcCopy = cv::Mat(dst_img.rows, dst_img.cols, dst_img.depth());
|
||||
cv::transpose(dst_img, srcCopy);
|
||||
cv::flip(srcCopy, srcCopy, 0);
|
||||
return srcCopy;
|
||||
} else {
|
||||
return dst_img;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace PaddleOCR
|
|
@ -1,31 +0,0 @@
|
|||
# model load config
|
||||
use_gpu 0
|
||||
gpu_id 0
|
||||
gpu_mem 4000
|
||||
cpu_math_library_num_threads 10
|
||||
use_mkldnn 0
|
||||
|
||||
# det config
|
||||
max_side_len 960
|
||||
det_db_thresh 0.3
|
||||
det_db_box_thresh 0.5
|
||||
det_db_unclip_ratio 1.6
|
||||
use_polygon_score 1
|
||||
det_model_dir ./inference/ch_ppocr_mobile_v2.0_det_infer/
|
||||
|
||||
# cls config
|
||||
use_angle_cls 0
|
||||
cls_model_dir ./inference/ch_ppocr_mobile_v2.0_cls_infer/
|
||||
cls_thresh 0.9
|
||||
|
||||
# rec config
|
||||
rec_model_dir ./inference/ch_ppocr_mobile_v2.0_rec_infer/
|
||||
char_list_file ../../ppocr/utils/ppocr_keys_v1.txt
|
||||
|
||||
# show the detection results
|
||||
visualize 0
|
||||
|
||||
# use_tensorrt
|
||||
use_tensorrt 0
|
||||
use_fp16 0
|
||||
|
|
@ -1,2 +0,0 @@
|
|||
|
||||
./build/ocr_system ./tools/config.txt ../../doc/imgs/12.jpg
|
|
@ -29,7 +29,8 @@ deploy/hubserving/ocr_system/
|
|||
### 1. 准备环境
|
||||
```shell
|
||||
# 安装paddlehub
|
||||
pip3 install paddlehub==1.8.3 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
# paddlehub 需要 python>3.6.2
|
||||
pip3 install paddlehub==2.1.0 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
```
|
||||
|
||||
### 2. 下载推理模型
|
||||
|
|
|
@ -30,7 +30,8 @@ The following steps take the 2-stage series service as an example. If only the d
|
|||
### 1. Prepare the environment
|
||||
```shell
|
||||
# Install paddlehub
|
||||
pip3 install paddlehub==1.8.3 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
# python>3.6.2 is required bt paddlehub
|
||||
pip3 install paddlehub==2.1.0 --upgrade -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
```
|
||||
|
||||
### 2. Download inference model
|
||||
|
|
|
@ -37,6 +37,17 @@ from paddleslim.dygraph.quant import QAT
|
|||
from ppocr.data import build_dataloader
|
||||
|
||||
|
||||
def export_single_model(quanter, model, infer_shape, save_path, logger):
|
||||
quanter.save_quantized_model(
|
||||
model,
|
||||
save_path,
|
||||
input_spec=[
|
||||
paddle.static.InputSpec(
|
||||
shape=[None] + infer_shape, dtype='float32')
|
||||
])
|
||||
logger.info('inference QAT model is saved to {}'.format(save_path))
|
||||
|
||||
|
||||
def main():
|
||||
############################################################################################################
|
||||
# 1. quantization configs
|
||||
|
@ -76,14 +87,21 @@ def main():
|
|||
# for rec algorithm
|
||||
if hasattr(post_process_class, 'character'):
|
||||
char_num = len(getattr(post_process_class, 'character'))
|
||||
if config['Architecture']["algorithm"] in ["Distillation",
|
||||
]: # distillation model
|
||||
for key in config['Architecture']["Models"]:
|
||||
config['Architecture']["Models"][key]["Head"][
|
||||
'out_channels'] = char_num
|
||||
else: # base rec model
|
||||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
|
||||
model = build_model(config['Architecture'])
|
||||
|
||||
# get QAT model
|
||||
quanter = QAT(config=quant_config)
|
||||
quanter.quantize(model)
|
||||
|
||||
init_model(config, model, logger)
|
||||
init_model(config, model)
|
||||
model.eval()
|
||||
|
||||
# build metric
|
||||
|
@ -92,25 +110,30 @@ def main():
|
|||
# build dataloader
|
||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
model_type = config['Architecture']['model_type']
|
||||
# start eval
|
||||
metirc = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class)
|
||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class, model_type, use_srn)
|
||||
|
||||
logger.info('metric eval ***************')
|
||||
for k, v in metirc.items():
|
||||
for k, v in metric.items():
|
||||
logger.info('{}:{}'.format(k, v))
|
||||
|
||||
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
|
||||
infer_shape = [3, 32, 100] if config['Architecture'][
|
||||
'model_type'] != "det" else [3, 640, 640]
|
||||
|
||||
quanter.save_quantized_model(
|
||||
model,
|
||||
save_path,
|
||||
input_spec=[
|
||||
paddle.static.InputSpec(
|
||||
shape=[None] + infer_shape, dtype='float32')
|
||||
])
|
||||
logger.info('inference QAT model is saved to {}'.format(save_path))
|
||||
save_path = config["Global"]["save_inference_dir"]
|
||||
|
||||
arch_config = config["Architecture"]
|
||||
if arch_config["algorithm"] in ["Distillation", ]: # distillation model
|
||||
for idx, name in enumerate(model.model_name_list):
|
||||
sub_model_save_path = os.path.join(save_path, name, "inference")
|
||||
export_single_model(quanter, model.model_list[idx], infer_shape,
|
||||
sub_model_save_path, logger)
|
||||
else:
|
||||
save_path = os.path.join(save_path, "inference")
|
||||
export_single_model(quanter, model, infer_shape, save_path, logger)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -109,9 +109,18 @@ def main(config, device, logger, vdl_writer):
|
|||
# for rec algorithm
|
||||
if hasattr(post_process_class, 'character'):
|
||||
char_num = len(getattr(post_process_class, 'character'))
|
||||
if config['Architecture']["algorithm"] in ["Distillation",
|
||||
]: # distillation model
|
||||
for key in config['Architecture']["Models"]:
|
||||
config['Architecture']["Models"][key]["Head"][
|
||||
'out_channels'] = char_num
|
||||
else: # base rec model
|
||||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
model = build_model(config['Architecture'])
|
||||
|
||||
quanter = QAT(config=quant_config, act_preprocess=PACT)
|
||||
quanter.quantize(model)
|
||||
|
||||
if config['Global']['distributed']:
|
||||
model = paddle.DataParallel(model)
|
||||
|
||||
|
@ -132,8 +141,6 @@ def main(config, device, logger, vdl_writer):
|
|||
|
||||
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
|
||||
format(len(train_dataloader), len(valid_dataloader)))
|
||||
quanter = QAT(config=quant_config, act_preprocess=PACT)
|
||||
quanter.quantize(model)
|
||||
|
||||
# start train
|
||||
program.train(config, train_dataloader, valid_dataloader, device, model,
|
||||
|
|
|
@ -18,9 +18,9 @@ PaddleOCR 也提供了数据格式转换脚本,可以将官网 label 转换支
|
|||
|
||||
```
|
||||
# 将官网下载的标签文件转换为 train_icdar2015_label.txt
|
||||
python gen_label.py --mode="det" --root_path="icdar_c4_train_imgs/" \
|
||||
--input_path="ch4_training_localization_transcription_gt" \
|
||||
--output_label="train_icdar2015_label.txt"
|
||||
python gen_label.py --mode="det" --root_path="/path/to/icdar_c4_train_imgs/" \
|
||||
--input_path="/path/to/ch4_training_localization_transcription_gt" \
|
||||
--output_label="/path/to/train_icdar2015_label.txt"
|
||||
```
|
||||
|
||||
解压数据集和下载标注文件后,PaddleOCR/train_data/ 有两个文件夹和两个文件,分别是:
|
||||
|
|
|
@ -147,12 +147,12 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_m
|
|||
|
||||
如果输入图片的分辨率比较大,而且想使用更大的分辨率预测,可以设置det_limit_side_len 为想要的值,比如1216:
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||
```
|
||||
|
||||
如果想使用CPU进行预测,执行命令如下
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/2.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
```
|
||||
|
||||
<a name="DB文本检测模型推理"></a>
|
||||
|
@ -221,7 +221,7 @@ python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Gl
|
|||
|
||||
```
|
||||
|
||||
**SAST文本检测模型推理,需要设置参数`--det_algorithm="SAST"`,同时,还需要增加参数`--det_sast_polygon=True`,**可以执行如下命令:
|
||||
SAST文本检测模型推理,需要设置参数`--det_algorithm="SAST"`,同时,还需要增加参数`--det_sast_polygon=True`,可以执行如下命令:
|
||||
```
|
||||
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_sast_tt/" --det_sast_polygon=True
|
||||
```
|
||||
|
|
|
@ -0,0 +1,251 @@
|
|||
# 知识蒸馏
|
||||
|
||||
|
||||
## 1. 简介
|
||||
|
||||
### 1.1 知识蒸馏介绍
|
||||
|
||||
近年来,深度神经网络在计算机视觉、自然语言处理等领域被验证是一种极其有效的解决问题的方法。通过构建合适的神经网络,加以训练,最终网络模型的性能指标基本上都会超过传统算法。
|
||||
|
||||
在数据量足够大的情况下,通过合理构建网络模型的方式增加其参数量,可以显著改善模型性能,但是这又带来了模型复杂度急剧提升的问题。大模型在实际场景中使用的成本较高。
|
||||
|
||||
深度神经网络一般有较多的参数冗余,目前有几种主要的方法对模型进行压缩,减小其参数量。如裁剪、量化、知识蒸馏等,其中知识蒸馏是指使用教师模型(teacher model)去指导学生模型(student model)学习特定任务,保证小模型在参数量不变的情况下,得到比较大的性能提升。
|
||||
|
||||
此外,在知识蒸馏任务中,也衍生出了互学习的模型训练方法,论文[Deep Mutual Learning](https://arxiv.org/abs/1706.00384)中指出,使用两个完全相同的模型在训练的过程中互相监督,可以达到比单个模型训练更好的效果。
|
||||
|
||||
### 1.2 PaddleOCR知识蒸馏简介
|
||||
|
||||
无论是大模型蒸馏小模型,还是小模型之间互相学习,更新参数,他们本质上是都是不同模型之间输出或者特征图(feature map)之间的相互监督,区别仅在于 (1) 模型是否需要固定参数。(2) 模型是否需要加载预训练模型。
|
||||
|
||||
对于大模型蒸馏小模型的情况,大模型一般需要加载预训练模型并固定参数;对于小模型之间互相蒸馏的情况,小模型一般都不加载预训练模型,参数也都是可学习的状态。
|
||||
|
||||
在知识蒸馏任务中,不只有2个模型之间进行蒸馏的情况,多个模型之间互相学习的情况也非常普遍。因此在知识蒸馏代码框架中,也有必要支持该种类别的蒸馏方法。
|
||||
|
||||
PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要的特点:
|
||||
- 支持任意网络的互相学习,不要求子网络结构完全一致或者具有预训练模型;同时子网络数量也没有任何限制,只需要在配置文件中添加即可。
|
||||
- 支持loss函数通过配置文件任意配置,不仅可以使用某种loss,也可以使用多种loss的组合
|
||||
- 支持知识蒸馏训练、预测、评估与导出等所有模型相关的环境,方便使用与部署。
|
||||
|
||||
|
||||
通过知识蒸馏,在中英文通用文字识别任务中,不增加任何预测耗时的情况下,可以给模型带来3%以上的精度提升,结合学习率调整策略以及模型结构微调策略,最终提升提升超过5%。
|
||||
|
||||
|
||||
|
||||
## 2. 配置文件解析
|
||||
|
||||
在知识蒸馏训练的过程中,数据预处理、优化器、学习率、全局的一些属性没有任何变化。模型结构、损失函数、后处理、指标计算等模块的配置文件需要进行微调。
|
||||
|
||||
下面以识别与检测的知识蒸馏配置文件为例,对知识蒸馏的训练与配置进行解析。
|
||||
|
||||
### 2.1 识别配置文件解析
|
||||
|
||||
配置文件在[rec_chinese_lite_train_distillation_v2.1.yml](../../configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml)。
|
||||
|
||||
#### 2.1.1 模型结构
|
||||
|
||||
知识蒸馏任务中,模型结构配置如下所示。
|
||||
|
||||
```yaml
|
||||
Architecture:
|
||||
model_type: &model_type "rec" # 模型类别,rec、det等,每个子网络的的模型类别都与
|
||||
name: DistillationModel # 结构名称,蒸馏任务中,为DistillationModel,用于构建对应的结构
|
||||
algorithm: Distillation # 算法名称
|
||||
Models: # 模型,包含子网络的配置信息
|
||||
Teacher: # 子网络名称,至少需要包含`pretrained`与`freeze_params`信息,其他的参数为子网络的构造参数
|
||||
pretrained: # 该子网络是否需要加载预训练模型
|
||||
freeze_params: false # 是否需要固定参数
|
||||
return_all_feats: true # 子网络的参数,表示是否需要返回所有的features,如果为False,则只返回最后的输出
|
||||
model_type: *model_type # 模型类别
|
||||
algorithm: CRNN # 子网络的算法名称,该子网络剩余参与均为构造参数,与普通的模型训练配置一致
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
Student: # 另外一个子网络,这里给的是DML的蒸馏示例,两个子网络结构相同,均需要学习参数
|
||||
pretrained: # 下面的组网参数同上
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
```
|
||||
|
||||
当然,这里如果希望添加更多的子网络进行训练,也可以按照`Student`与`Teacher`的添加方式,在配置文件中添加相应的字段。比如说如果希望有3个模型互相监督,共同训练,那么`Architecture`可以写为如下格式。
|
||||
|
||||
```yaml
|
||||
Architecture:
|
||||
model_type: &model_type "rec"
|
||||
name: DistillationModel
|
||||
algorithm: Distillation
|
||||
Models:
|
||||
Teacher:
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
Student:
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
Student2: # 知识蒸馏任务中引入的新的子网络,其他部分与上述配置相同
|
||||
pretrained:
|
||||
freeze_params: false
|
||||
return_all_feats: true
|
||||
model_type: *model_type
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV1Enhance
|
||||
scale: 0.5
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 64
|
||||
Head:
|
||||
name: CTCHead
|
||||
mid_channels: 96
|
||||
fc_decay: 0.00002
|
||||
```
|
||||
|
||||
最终该模型训练时,包含3个子网络:`Teacher`, `Student`, `Student2`。
|
||||
|
||||
蒸馏模型`DistillationModel`类的具体实现代码可以参考[distillation_model.py](../../ppocr/modeling/architectures/distillation_model.py)。
|
||||
|
||||
最终模型`forward`输出为一个字典,key为所有的子网络名称,例如这里为`Student`与`Teacher`,value为对应子网络的输出,可以为`Tensor`(只返回该网络的最后一层)和`dict`(也返回了中间的特征信息)。
|
||||
|
||||
在识别任务中,为了添加更多损失函数,保证蒸馏方法的可扩展性,将每个子网络的输出保存为`dict`,其中包含子模块输出。以该识别模型为例,每个子网络的输出结果均为`dict`,key包含`backbone_out`,`neck_out`, `head_out`,`value`为对应模块的tensor,最终对于上述配置文件,`DistillationModel`的输出格式如下。
|
||||
|
||||
```json
|
||||
{
|
||||
"Teacher": {
|
||||
"backbone_out": tensor,
|
||||
"neck_out": tensor,
|
||||
"head_out": tensor,
|
||||
},
|
||||
"Student": {
|
||||
"backbone_out": tensor,
|
||||
"neck_out": tensor,
|
||||
"head_out": tensor,
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### 2.1.2 损失函数
|
||||
|
||||
知识蒸馏任务中,损失函数配置如下所示。
|
||||
|
||||
```yaml
|
||||
Loss:
|
||||
name: CombinedLoss # 损失函数名称,基于改名称,构建用于损失函数的类
|
||||
loss_config_list: # 损失函数配置文件列表,为CombinedLoss的必备函数
|
||||
- DistillationCTCLoss: # 基于蒸馏的CTC损失函数,继承自标准的CTC loss
|
||||
weight: 1.0 # 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
|
||||
model_name_list: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,与gt计算CTC loss
|
||||
key: head_out # 取子网络输出dict中,该key对应的tensor
|
||||
- DistillationDMLLoss: # 蒸馏的DML损失函数,继承自标准的DMLLoss
|
||||
weight: 1.0 # 权重
|
||||
act: "softmax" # 激活函数,对输入使用激活函数处理,可以为softmax, sigmoid或者为None,默认为None
|
||||
model_name_pairs: # 用于计算DML loss的子网络名称对,如果希望计算其他子网络的DML loss,可以在列表下面继续填充
|
||||
- ["Student", "Teacher"]
|
||||
key: head_out # 取子网络输出dict中,该key对应的tensor
|
||||
- DistillationDistanceLoss: # 蒸馏的距离损失函数
|
||||
weight: 1.0 # 权重
|
||||
mode: "l2" # 距离计算方法,目前支持l1, l2, smooth_l1
|
||||
model_name_pairs: # 用于计算distance loss的子网络名称对
|
||||
- ["Student", "Teacher"]
|
||||
key: backbone_out # 取子网络输出dict中,该key对应的tensor
|
||||
```
|
||||
|
||||
上述损失函数中,所有的蒸馏损失函数均继承自标准的损失函数类,主要功能为: 对蒸馏模型的输出进行解析,找到用于计算损失的中间节点(tensor),再使用标准的损失函数类去计算。
|
||||
|
||||
以上述配置为例,最终蒸馏训练的损失函数包含下面3个部分。
|
||||
|
||||
- `Student`和`Teacher`的最终输出(`head_out`)与gt的CTC loss,权重为1。在这里因为2个子网络都需要更新参数,因此2者都需要计算与g的loss。
|
||||
- `Student`和`Teacher`的最终输出(`head_out`)之间的DML loss,权重为1。
|
||||
- `Student`和`Teacher`的骨干网络输出(`backbone_out`)之间的l2 loss,权重为1。
|
||||
|
||||
关于`CombinedLoss`更加具体的实现可以参考: [combined_loss.py](../../ppocr/losses/combined_loss.py#L23)。关于`DistillationCTCLoss`等蒸馏损失函数更加具体的实现可以参考[distillation_loss.py](../../ppocr/losses/distillation_loss.py)。
|
||||
|
||||
|
||||
#### 2.1.3 后处理
|
||||
|
||||
知识蒸馏任务中,后处理配置如下所示。
|
||||
|
||||
```yaml
|
||||
PostProcess:
|
||||
name: DistillationCTCLabelDecode # 蒸馏任务的CTC解码后处理,继承自标准的CTCLabelDecode类
|
||||
model_name: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,进行解码
|
||||
key: head_out # 取子网络输出dict中,该key对应的tensor
|
||||
```
|
||||
|
||||
以上述配置为例,最终会同时计算`Student`和`Teahcer` 2个子网络的CTC解码输出,返回一个`dict`,`key`为用于处理的子网络名称,`value`为用于处理的子网络列表。
|
||||
|
||||
关于`DistillationCTCLabelDecode`更加具体的实现可以参考: [rec_postprocess.py](../../ppocr/postprocess/rec_postprocess.py#L128)
|
||||
|
||||
|
||||
#### 2.1.4 指标计算
|
||||
|
||||
知识蒸馏任务中,指标计算配置如下所示。
|
||||
|
||||
```yaml
|
||||
Metric:
|
||||
name: DistillationMetric # 蒸馏任务的CTC解码后处理,继承自标准的CTCLabelDecode类
|
||||
base_metric_name: RecMetric # 指标计算的基类,对于模型的输出,会基于该类,计算指标
|
||||
main_indicator: acc # 指标的名称
|
||||
key: "Student" # 选取该子网络的 main_indicator 作为作为保存保存best model的判断标准
|
||||
```
|
||||
|
||||
以上述配置为例,最终会使用`Student`子网络的acc指标作为保存best model的判断指标,同时,日志中也会打印出所有子网络的acc指标。
|
||||
|
||||
关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24)。
|
||||
|
||||
|
||||
### 2.2 检测配置文件解析
|
||||
|
||||
* coming soon!
|
|
@ -331,6 +331,8 @@ PaddleOCR目前已支持80种(除中文外)语种识别,`configs/rec/multi
|
|||
|
||||
```
|
||||
|
||||
意大利文由拉丁字母组成,因此执行完命令后会得到名为 rec_latin_lite_train.yml 的配置文件。
|
||||
|
||||
2. 手动修改配置文件
|
||||
|
||||
您也可以手动修改模版中的以下几个字段:
|
||||
|
@ -376,7 +378,9 @@ PaddleOCR目前已支持80种(除中文外)语种识别,`configs/rec/multi
|
|||
|
||||
更多支持语种请参考: [多语言模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_ch/multi_languages.md#%E8%AF%AD%E7%A7%8D%E7%BC%A9%E5%86%99)
|
||||
|
||||
多语言模型训练方式与中文模型一致,训练数据集均为100w的合成数据,少量的字体可以在 [百度网盘](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA) 上下载,提取码:frgi。
|
||||
多语言模型训练方式与中文模型一致,训练数据集均为100w的合成数据,少量的字体可以通过下面两种方式下载。
|
||||
* [百度网盘](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA)。提取码:frgi。
|
||||
* [google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
|
||||
|
||||
如您希望在现有模型效果的基础上调优,请参考下列说明修改配置文件:
|
||||
|
||||
|
|
|
@ -5,23 +5,29 @@
|
|||
### 1.1 安装whl包
|
||||
|
||||
pip安装
|
||||
|
||||
```bash
|
||||
pip install "paddleocr>=2.0.1" # 推荐使用2.0.1+版本
|
||||
```
|
||||
|
||||
本地构建并安装
|
||||
|
||||
```bash
|
||||
python3 setup.py bdist_wheel
|
||||
pip3 install dist/paddleocr-x.x.x-py3-none-any.whl # x.x.x是paddleocr的版本号
|
||||
```
|
||||
|
||||
## 2 使用
|
||||
|
||||
### 2.1 代码使用
|
||||
|
||||
paddleocr whl包会自动下载ppocr轻量级模型作为默认模型,可以根据第3节**自定义模型**进行自定义更换。
|
||||
|
||||
* 检测+方向分类器+识别全流程
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
|
||||
# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换
|
||||
# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||
|
@ -32,6 +38,7 @@ for line in result:
|
|||
|
||||
# 显示结果
|
||||
from PIL import Image
|
||||
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
|
@ -40,31 +47,36 @@ im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc
|
|||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
结果是一个list,每个item包含了文本框,文字和识别置信度
|
||||
|
||||
```bash
|
||||
[[[24.0, 36.0], [304.0, 34.0], [304.0, 72.0], [24.0, 74.0]], ['纯臻营养护发素', 0.964739]]
|
||||
[[[24.0, 80.0], [172.0, 80.0], [172.0, 104.0], [24.0, 104.0]], ['产品信息/参数', 0.98069626]]
|
||||
[[[24.0, 109.0], [333.0, 109.0], [333.0, 136.0], [24.0, 136.0]], ['(45元/每公斤,100公斤起订)', 0.9676722]]
|
||||
......
|
||||
```
|
||||
|
||||
结果可视化
|
||||
|
||||
<div align="center">
|
||||
<img src="../imgs_results/whl/11_det_rec.jpg" width="800">
|
||||
</div>
|
||||
|
||||
|
||||
* 检测+识别
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
|
||||
ocr = PaddleOCR() # need to run only once to download and load model into memory
|
||||
img_path = 'PaddleOCR/doc/imgs/11.jpg'
|
||||
result = ocr.ocr(img_path,cls=False)
|
||||
result = ocr.ocr(img_path, cls=False)
|
||||
for line in result:
|
||||
print(line)
|
||||
|
||||
# 显示结果
|
||||
from PIL import Image
|
||||
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
|
@ -73,37 +85,45 @@ im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc
|
|||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
结果是一个list,每个item包含了文本框,文字和识别置信度
|
||||
|
||||
```bash
|
||||
[[[24.0, 36.0], [304.0, 34.0], [304.0, 72.0], [24.0, 74.0]], ['纯臻营养护发素', 0.964739]]
|
||||
[[[24.0, 80.0], [172.0, 80.0], [172.0, 104.0], [24.0, 104.0]], ['产品信息/参数', 0.98069626]]
|
||||
[[[24.0, 109.0], [333.0, 109.0], [333.0, 136.0], [24.0, 136.0]], ['(45元/每公斤,100公斤起订)', 0.9676722]]
|
||||
......
|
||||
```
|
||||
|
||||
结果可视化
|
||||
|
||||
<div align="center">
|
||||
<img src="../imgs_results/whl/11_det_rec.jpg" width="800">
|
||||
</div>
|
||||
|
||||
|
||||
* 方向分类器+识别
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
ocr = PaddleOCR(use_angle_cls=True) # need to run only once to download and load model into memory
|
||||
img_path = 'PaddleOCR/doc/imgs_words/ch/word_1.jpg'
|
||||
result = ocr.ocr(img_path, det=False, cls=True)
|
||||
for line in result:
|
||||
print(line)
|
||||
```
|
||||
|
||||
结果是一个list,每个item只包含识别结果和识别置信度
|
||||
|
||||
```bash
|
||||
['韩国小馆', 0.9907421]
|
||||
```
|
||||
|
||||
* 单独执行检测
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
|
||||
ocr = PaddleOCR() # need to run only once to download and load model into memory
|
||||
img_path = 'PaddleOCR/doc/imgs/11.jpg'
|
||||
result = ocr.ocr(img_path, rec=False)
|
||||
|
@ -118,13 +138,16 @@ im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/Pa
|
|||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
结果是一个list,每个item只包含文本框
|
||||
|
||||
```bash
|
||||
[[26.0, 457.0], [137.0, 457.0], [137.0, 477.0], [26.0, 477.0]]
|
||||
[[25.0, 425.0], [372.0, 425.0], [372.0, 448.0], [25.0, 448.0]]
|
||||
[[128.0, 397.0], [273.0, 397.0], [273.0, 414.0], [128.0, 414.0]]
|
||||
......
|
||||
```
|
||||
|
||||
结果可视化
|
||||
|
||||
|
||||
|
@ -133,29 +156,37 @@ im_show.save('result.jpg')
|
|||
</div>
|
||||
|
||||
* 单独执行识别
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
ocr = PaddleOCR() # need to run only once to download and load model into memory
|
||||
img_path = 'PaddleOCR/doc/imgs_words/ch/word_1.jpg'
|
||||
result = ocr.ocr(img_path, det=False)
|
||||
for line in result:
|
||||
print(line)
|
||||
```
|
||||
|
||||
结果是一个list,每个item只包含识别结果和识别置信度
|
||||
|
||||
```bash
|
||||
['韩国小馆', 0.9907421]
|
||||
```
|
||||
|
||||
* 单独执行方向分类器
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
ocr = PaddleOCR(use_angle_cls=True) # need to run only once to download and load model into memory
|
||||
img_path = 'PaddleOCR/doc/imgs_words/ch/word_1.jpg'
|
||||
result = ocr.ocr(img_path, det=False, rec=False, cls=True)
|
||||
for line in result:
|
||||
print(line)
|
||||
```
|
||||
|
||||
结果是一个list,每个item只包含分类结果和分类置信度
|
||||
|
||||
```bash
|
||||
['0', 0.9999924]
|
||||
```
|
||||
|
@ -163,15 +194,19 @@ for line in result:
|
|||
### 2.2 通过命令行使用
|
||||
|
||||
查看帮助信息
|
||||
|
||||
```bash
|
||||
paddleocr -h
|
||||
```
|
||||
|
||||
* 检测+方向分类器+识别全流程
|
||||
|
||||
```bash
|
||||
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --use_angle_cls true
|
||||
```
|
||||
|
||||
结果是一个list,每个item包含了文本框,文字和识别置信度
|
||||
|
||||
```bash
|
||||
[[[24.0, 36.0], [304.0, 34.0], [304.0, 72.0], [24.0, 74.0]], ['纯臻营养护发素', 0.964739]]
|
||||
[[[24.0, 80.0], [172.0, 80.0], [172.0, 104.0], [24.0, 104.0]], ['产品信息/参数', 0.98069626]]
|
||||
|
@ -180,10 +215,13 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --use_angle_cls true
|
|||
```
|
||||
|
||||
* 检测+识别
|
||||
|
||||
```bash
|
||||
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg
|
||||
```
|
||||
|
||||
结果是一个list,每个item包含了文本框,文字和识别置信度
|
||||
|
||||
```bash
|
||||
[[[24.0, 36.0], [304.0, 34.0], [304.0, 72.0], [24.0, 74.0]], ['纯臻营养护发素', 0.964739]]
|
||||
[[[24.0, 80.0], [172.0, 80.0], [172.0, 104.0], [24.0, 104.0]], ['产品信息/参数', 0.98069626]]
|
||||
|
@ -192,20 +230,25 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg
|
|||
```
|
||||
|
||||
* 方向分类器+识别
|
||||
|
||||
```bash
|
||||
paddleocr --image_dir PaddleOCR/doc/imgs_words/ch/word_1.jpg --use_angle_cls true --det false
|
||||
```
|
||||
|
||||
结果是一个list,每个item只包含识别结果和识别置信度
|
||||
|
||||
```bash
|
||||
['韩国小馆', 0.9907421]
|
||||
```
|
||||
|
||||
* 单独执行检测
|
||||
|
||||
```bash
|
||||
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --rec false
|
||||
```
|
||||
|
||||
结果是一个list,每个item只包含文本框
|
||||
|
||||
```bash
|
||||
[[26.0, 457.0], [137.0, 457.0], [137.0, 477.0], [26.0, 477.0]]
|
||||
[[25.0, 425.0], [372.0, 425.0], [372.0, 448.0], [25.0, 448.0]]
|
||||
|
@ -214,34 +257,42 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --rec false
|
|||
```
|
||||
|
||||
* 单独执行识别
|
||||
|
||||
```bash
|
||||
paddleocr --image_dir PaddleOCR/doc/imgs_words/ch/word_1.jpg --det false
|
||||
```
|
||||
|
||||
结果是一个list,每个item只包含识别结果和识别置信度
|
||||
|
||||
```bash
|
||||
['韩国小馆', 0.9907421]
|
||||
```
|
||||
|
||||
* 单独执行方向分类器
|
||||
|
||||
```bash
|
||||
paddleocr --image_dir PaddleOCR/doc/imgs_words/ch/word_1.jpg --use_angle_cls true --det false --rec false
|
||||
```
|
||||
|
||||
结果是一个list,每个item只包含分类结果和分类置信度
|
||||
|
||||
```bash
|
||||
['0', 0.9999924]
|
||||
```
|
||||
|
||||
## 3 自定义模型
|
||||
当内置模型无法满足需求时,需要使用到自己训练的模型。
|
||||
首先,参照[inference.md](./inference.md) 第一节转换将检测、分类和识别模型转换为inference模型,然后按照如下方式使用
|
||||
|
||||
当内置模型无法满足需求时,需要使用到自己训练的模型。 首先,参照[inference.md](./inference.md) 第一节转换将检测、分类和识别模型转换为inference模型,然后按照如下方式使用
|
||||
|
||||
### 3.1 代码使用
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
|
||||
# 模型路径下必须含有model和params文件
|
||||
ocr = PaddleOCR(det_model_dir='{your_det_model_dir}', rec_model_dir='{your_rec_model_dir}', rec_char_dict_path='{your_rec_char_dict_path}', cls_model_dir='{your_cls_model_dir}', use_angle_cls=True)
|
||||
ocr = PaddleOCR(det_model_dir='{your_det_model_dir}', rec_model_dir='{your_rec_model_dir}',
|
||||
rec_char_dict_path='{your_rec_char_dict_path}', cls_model_dir='{your_cls_model_dir}',
|
||||
use_angle_cls=True)
|
||||
img_path = 'PaddleOCR/doc/imgs/11.jpg'
|
||||
result = ocr.ocr(img_path, cls=True)
|
||||
for line in result:
|
||||
|
@ -249,6 +300,7 @@ for line in result:
|
|||
|
||||
# 显示结果
|
||||
from PIL import Image
|
||||
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
|
@ -269,8 +321,10 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_
|
|||
### 4.1 网络图片
|
||||
|
||||
- 代码使用
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
from paddleocr import PaddleOCR, draw_ocr, download_with_progressbar
|
||||
|
||||
# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换
|
||||
# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||
|
@ -281,7 +335,9 @@ for line in result:
|
|||
|
||||
# 显示结果
|
||||
from PIL import Image
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
|
||||
download_with_progressbar(img_path, 'tmp.jpg')
|
||||
image = Image.open('tmp.jpg').convert('RGB')
|
||||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
|
@ -289,15 +345,21 @@ im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc
|
|||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
- 命令行模式
|
||||
|
||||
```bash
|
||||
paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg --use_angle_cls=true
|
||||
```
|
||||
|
||||
### 4.2 numpy数组
|
||||
|
||||
仅通过代码使用时支持numpy数组作为输入
|
||||
|
||||
```python
|
||||
import cv2
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
|
||||
# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换
|
||||
# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||
|
@ -310,6 +372,7 @@ for line in result:
|
|||
|
||||
# 显示结果
|
||||
from PIL import Image
|
||||
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
|
@ -355,3 +418,5 @@ im_show.save('result.jpg')
|
|||
| det | 前向时使用启动检测 | TRUE |
|
||||
| rec | 前向时是否启动识别 | TRUE |
|
||||
| cls | 前向时是否启动分类 (命令行模式下使用use_angle_cls控制前向是否启动分类) | FALSE |
|
||||
| show_log | 是否打印det和rec等信息 | FALSE |
|
||||
| type | 执行ocr或者表格结构化, 值可选['ocr','structure'] | ocr |
|
||||
|
|
|
@ -154,12 +154,12 @@ Set as `limit_type='min', det_limit_side_len=960`, it means that the shortest si
|
|||
|
||||
If the resolution of the input picture is relatively large and you want to use a larger resolution prediction, you can set det_limit_side_len to the desired value, such as 1216:
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
|
||||
```
|
||||
|
||||
If you want to use the CPU for prediction, execute the command as follows
|
||||
```
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/22.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
|
||||
```
|
||||
|
||||
<a name="DB_DETECTION"></a>
|
||||
|
@ -230,7 +230,7 @@ First, convert the model saved in the SAST text detection training process into
|
|||
python3 tools/export_model.py -c configs/det/det_r50_vd_sast_totaltext.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.save_inference_dir=./inference/det_sast_tt
|
||||
```
|
||||
|
||||
**For SAST curved text detection model inference, you need to set the parameter `--det_algorithm="SAST"` and `--det_sast_polygon=True`**, run the following command:
|
||||
For SAST curved text detection model inference, you need to set the parameter `--det_algorithm="SAST"` and `--det_sast_polygon=True`, run the following command:
|
||||
|
||||
```
|
||||
python3 tools/infer/predict_det.py --det_algorithm="SAST" --image_dir="./doc/imgs_en/img623.jpg" --det_model_dir="./inference/det_sast_tt/" --det_sast_polygon=True
|
||||
|
|
|
@ -329,6 +329,7 @@ There are two ways to create the required configuration file::
|
|||
...
|
||||
|
||||
```
|
||||
Italian is made up of Latin letters, so after executing the command, you will get the rec_latin_lite_train.yml.
|
||||
|
||||
2. Manually modify the configuration file
|
||||
|
||||
|
@ -375,7 +376,9 @@ Currently, the multi-language algorithms supported by PaddleOCR are:
|
|||
|
||||
For more supported languages, please refer to : [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations)
|
||||
|
||||
The multi-language model training method is the same as the Chinese model. The training data set is 100w synthetic data. A small amount of fonts and test data can be downloaded on [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA),Extraction code:frgi.
|
||||
The multi-language model training method is the same as the Chinese model. The training data set is 100w synthetic data. A small amount of fonts and test data can be downloaded using the following two methods.
|
||||
* [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA),Extraction code:frgi.
|
||||
* [Google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
|
||||
|
||||
If you want to finetune on the basis of the existing model effect, please refer to the following instructions to modify the configuration file:
|
||||
|
||||
|
|
|
@ -15,8 +15,6 @@
|
|||
- 2020.6.8 Add [datasets](./datasets_en.md) and keep updating
|
||||
- 2020.6.5 Support exporting `attention` model to `inference_model`
|
||||
- 2020.6.5 Support separate prediction and recognition, output result score
|
||||
- 2020.6.5 Support exporting `attention` model to `inference_model`
|
||||
- 2020.6.5 Support separate prediction and recognition, output result score
|
||||
- 2020.5.30 Provide Lightweight Chinese OCR online experience
|
||||
- 2020.5.30 Model prediction and training support on Windows system
|
||||
- 2020.5.30 Open source general Chinese OCR model
|
||||
|
|
|
@ -305,7 +305,8 @@ paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-f
|
|||
Support numpy array as input only when used by code
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
import cv2
|
||||
from paddleocr import PaddleOCR, draw_ocr, download_with_progressbar
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||
img_path = 'PaddleOCR/doc/imgs/11.jpg'
|
||||
img = cv2.imread(img_path)
|
||||
|
@ -316,7 +317,9 @@ for line in result:
|
|||
|
||||
# show result
|
||||
from PIL import Image
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
|
||||
download_with_progressbar(img_path, 'tmp.jpg')
|
||||
image = Image.open('tmp.jpg').convert('RGB')
|
||||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
|
@ -362,3 +365,5 @@ im_show.save('result.jpg')
|
|||
| det | Enable detction when `ppocr.ocr` func exec | TRUE |
|
||||
| rec | Enable recognition when `ppocr.ocr` func exec | TRUE |
|
||||
| cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE |
|
||||
| show_log | Whether to print log in det and rec | FALSE |
|
||||
| type | Perform ocr or table structuring, the value is selected in ['ocr','structure'] | ocr |
|
BIN
doc/joinus.PNG
Before Width: | Height: | Size: 78 KiB After Width: | Height: | Size: 541 KiB |
After Width: | Height: | Size: 263 KiB |
After Width: | Height: | Size: 672 KiB |
After Width: | Height: | Size: 672 KiB |
After Width: | Height: | Size: 1.5 MiB |
After Width: | Height: | Size: 1.4 MiB |
After Width: | Height: | Size: 2.5 MiB |
After Width: | Height: | Size: 521 KiB |
After Width: | Height: | Size: 146 KiB |
After Width: | Height: | Size: 24 KiB |
After Width: | Height: | Size: 552 KiB |
After Width: | Height: | Size: 416 KiB |
220
paddleocr.py
|
@ -19,27 +19,29 @@ __dir__ = os.path.dirname(__file__)
|
|||
sys.path.append(os.path.join(__dir__, ''))
|
||||
|
||||
import cv2
|
||||
import logging
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import tarfile
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from tools.infer import predict_system
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
|
||||
from tools.infer.utility import draw_ocr, init_args, str2bool
|
||||
from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
|
||||
from tools.infer.utility import draw_ocr, str2bool
|
||||
from ppstructure.utility import init_args, draw_structure_result
|
||||
from ppstructure.predict_system import OCRSystem, save_structure_res
|
||||
|
||||
__all__ = ['PaddleOCR']
|
||||
__all__ = ['PaddleOCR', 'PPStructure', 'draw_ocr', 'draw_structure_result', 'save_structure_res','download_with_progressbar']
|
||||
|
||||
model_urls = {
|
||||
'det': {
|
||||
'ch':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
|
||||
'en':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar',
|
||||
'structure': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar'
|
||||
},
|
||||
'rec': {
|
||||
'ch': {
|
||||
|
@ -111,62 +113,25 @@ model_urls = {
|
|||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
|
||||
},
|
||||
'structure': {
|
||||
'url': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
|
||||
'dict_path': 'ppocr/utils/dict/table_dict.txt'
|
||||
}
|
||||
},
|
||||
'cls':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
|
||||
'cls': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
|
||||
'table': {
|
||||
'url': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar',
|
||||
'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
|
||||
}
|
||||
}
|
||||
|
||||
SUPPORT_DET_MODEL = ['DB']
|
||||
VERSION = '2.1'
|
||||
VERSION = '2.2'
|
||||
SUPPORT_REC_MODEL = ['CRNN']
|
||||
BASE_DIR = os.path.expanduser("~/.paddleocr/")
|
||||
|
||||
|
||||
def download_with_progressbar(url, save_path):
|
||||
response = requests.get(url, stream=True)
|
||||
total_size_in_bytes = int(response.headers.get('content-length', 0))
|
||||
block_size = 1024 # 1 Kibibyte
|
||||
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
|
||||
with open(save_path, 'wb') as file:
|
||||
for data in response.iter_content(block_size):
|
||||
progress_bar.update(len(data))
|
||||
file.write(data)
|
||||
progress_bar.close()
|
||||
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
|
||||
logger.error("Something went wrong while downloading models")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def maybe_download(model_storage_directory, url):
|
||||
# using custom model
|
||||
tar_file_name_list = [
|
||||
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
|
||||
]
|
||||
if not os.path.exists(
|
||||
os.path.join(model_storage_directory, 'inference.pdiparams')
|
||||
) or not os.path.exists(
|
||||
os.path.join(model_storage_directory, 'inference.pdmodel')):
|
||||
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
|
||||
print('download {} to {}'.format(url, tmp_path))
|
||||
os.makedirs(model_storage_directory, exist_ok=True)
|
||||
download_with_progressbar(url, tmp_path)
|
||||
with tarfile.open(tmp_path, 'r') as tarObj:
|
||||
for member in tarObj.getmembers():
|
||||
filename = None
|
||||
for tar_file_name in tar_file_name_list:
|
||||
if tar_file_name in member.name:
|
||||
filename = tar_file_name
|
||||
if filename is None:
|
||||
continue
|
||||
file = tarObj.extractfile(member)
|
||||
with open(
|
||||
os.path.join(model_storage_directory, filename),
|
||||
'wb') as f:
|
||||
f.write(file.read())
|
||||
os.remove(tmp_path)
|
||||
|
||||
|
||||
def parse_args(mMain=True):
|
||||
import argparse
|
||||
parser = init_args()
|
||||
|
@ -174,9 +139,10 @@ def parse_args(mMain=True):
|
|||
parser.add_argument("--lang", type=str, default='ch')
|
||||
parser.add_argument("--det", type=str2bool, default=True)
|
||||
parser.add_argument("--rec", type=str2bool, default=True)
|
||||
parser.add_argument("--type", type=str, default='ocr')
|
||||
|
||||
for action in parser._actions:
|
||||
if action.dest == 'rec_char_dict_path':
|
||||
if action.dest in ['rec_char_dict_path', 'table_char_dict_path']:
|
||||
action.default = None
|
||||
if mMain:
|
||||
return parser.parse_args()
|
||||
|
@ -187,17 +153,7 @@ def parse_args(mMain=True):
|
|||
return argparse.Namespace(**inference_args_dict)
|
||||
|
||||
|
||||
class PaddleOCR(predict_system.TextSystem):
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
paddleocr package
|
||||
args:
|
||||
**kwargs: other params show in paddleocr --help
|
||||
"""
|
||||
postprocess_params = parse_args(mMain=False)
|
||||
postprocess_params.__dict__.update(**kwargs)
|
||||
self.use_angle_cls = postprocess_params.use_angle_cls
|
||||
lang = postprocess_params.lang
|
||||
def parse_lang(lang):
|
||||
latin_lang = [
|
||||
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
|
||||
'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
|
||||
|
@ -226,43 +182,55 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
model_urls['rec'].keys(), lang)
|
||||
if lang == "ch":
|
||||
det_lang = "ch"
|
||||
elif lang == 'structure':
|
||||
det_lang = 'structure'
|
||||
else:
|
||||
det_lang = "en"
|
||||
use_inner_dict = False
|
||||
if postprocess_params.rec_char_dict_path is None:
|
||||
use_inner_dict = True
|
||||
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
|
||||
'dict_path']
|
||||
return lang, det_lang
|
||||
|
||||
|
||||
class PaddleOCR(predict_system.TextSystem):
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
paddleocr package
|
||||
args:
|
||||
**kwargs: other params show in paddleocr --help
|
||||
"""
|
||||
params = parse_args(mMain=False)
|
||||
params.__dict__.update(**kwargs)
|
||||
if not params.show_log:
|
||||
logger.setLevel(logging.INFO)
|
||||
self.use_angle_cls = params.use_angle_cls
|
||||
lang, det_lang = parse_lang(params.lang)
|
||||
|
||||
# init model dir
|
||||
if postprocess_params.det_model_dir is None:
|
||||
postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION,
|
||||
'det', det_lang)
|
||||
if postprocess_params.rec_model_dir is None:
|
||||
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
|
||||
'rec', lang)
|
||||
if postprocess_params.cls_model_dir is None:
|
||||
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
|
||||
print(postprocess_params)
|
||||
# download model
|
||||
maybe_download(postprocess_params.det_model_dir,
|
||||
params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
|
||||
model_urls['det'][det_lang])
|
||||
maybe_download(postprocess_params.rec_model_dir,
|
||||
params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
|
||||
model_urls['rec'][lang]['url'])
|
||||
maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
|
||||
params.cls_model_dir, cls_url = confirm_model_dir_url(params.cls_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'ocr', 'cls'),
|
||||
model_urls['cls'])
|
||||
# download model
|
||||
maybe_download(params.det_model_dir, det_url)
|
||||
maybe_download(params.rec_model_dir, rec_url)
|
||||
maybe_download(params.cls_model_dir, cls_url)
|
||||
|
||||
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
|
||||
if params.det_algorithm not in SUPPORT_DET_MODEL:
|
||||
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
|
||||
sys.exit(0)
|
||||
if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
|
||||
if params.rec_algorithm not in SUPPORT_REC_MODEL:
|
||||
logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
|
||||
sys.exit(0)
|
||||
if use_inner_dict:
|
||||
postprocess_params.rec_char_dict_path = str(
|
||||
Path(__file__).parent / postprocess_params.rec_char_dict_path)
|
||||
|
||||
if params.rec_char_dict_path is None:
|
||||
params.rec_char_dict_path = str(Path(__file__).parent / model_urls['rec'][lang]['dict_path'])
|
||||
|
||||
print(params)
|
||||
# init det_model and rec_model
|
||||
super().__init__(postprocess_params)
|
||||
super().__init__(params)
|
||||
|
||||
def ocr(self, img, det=True, rec=True, cls=True):
|
||||
"""
|
||||
|
@ -316,11 +284,64 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
return rec_res
|
||||
|
||||
|
||||
class PPStructure(OCRSystem):
|
||||
def __init__(self, **kwargs):
|
||||
params = parse_args(mMain=False)
|
||||
params.__dict__.update(**kwargs)
|
||||
if not params.show_log:
|
||||
logger.setLevel(logging.INFO)
|
||||
lang, det_lang = parse_lang(params.lang)
|
||||
|
||||
# init model dir
|
||||
params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
|
||||
model_urls['det'][det_lang])
|
||||
params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
|
||||
model_urls['rec'][lang]['url'])
|
||||
params.table_model_dir, table_url = confirm_model_dir_url(params.table_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'ocr', 'table'),
|
||||
model_urls['table']['url'])
|
||||
# download model
|
||||
maybe_download(params.det_model_dir, det_url)
|
||||
maybe_download(params.rec_model_dir, rec_url)
|
||||
maybe_download(params.table_model_dir, table_url)
|
||||
|
||||
if params.rec_char_dict_path is None:
|
||||
params.rec_char_dict_path = str(Path(__file__).parent / model_urls['rec'][lang]['dict_path'])
|
||||
if params.table_char_dict_path is None:
|
||||
params.table_char_dict_path = str(Path(__file__).parent / model_urls['table']['dict_path'])
|
||||
|
||||
print(params)
|
||||
super().__init__(params)
|
||||
|
||||
def __call__(self, img):
|
||||
if isinstance(img, str):
|
||||
# download net image
|
||||
if img.startswith('http'):
|
||||
download_with_progressbar(img, 'tmp.jpg')
|
||||
img = 'tmp.jpg'
|
||||
image_file = img
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
if not flag:
|
||||
with open(image_file, 'rb') as f:
|
||||
np_arr = np.frombuffer(f.read(), dtype=np.uint8)
|
||||
img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
logger.error("error in loading image:{}".format(image_file))
|
||||
return None
|
||||
if isinstance(img, np.ndarray) and len(img.shape) == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
|
||||
res = super().__call__(img)
|
||||
return res
|
||||
|
||||
|
||||
def main():
|
||||
# for cmd
|
||||
args = parse_args(mMain=True)
|
||||
image_dir = args.image_dir
|
||||
if image_dir.startswith('http'):
|
||||
if is_link(image_dir):
|
||||
download_with_progressbar(image_dir, 'tmp.jpg')
|
||||
image_file_list = ['tmp.jpg']
|
||||
else:
|
||||
|
@ -328,14 +349,29 @@ def main():
|
|||
if len(image_file_list) == 0:
|
||||
logger.error('no images find in {}'.format(args.image_dir))
|
||||
return
|
||||
if args.type == 'ocr':
|
||||
engine = PaddleOCR(**(args.__dict__))
|
||||
elif args.type == 'structure':
|
||||
engine = PPStructure(**(args.__dict__))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
ocr_engine = PaddleOCR(**(args.__dict__))
|
||||
for img_path in image_file_list:
|
||||
img_name = os.path.basename(img_path).split('.')[0]
|
||||
logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10))
|
||||
result = ocr_engine.ocr(img_path,
|
||||
if args.type == 'ocr':
|
||||
result = engine.ocr(img_path,
|
||||
det=args.det,
|
||||
rec=args.rec,
|
||||
cls=args.use_angle_cls)
|
||||
if result is not None:
|
||||
for line in result:
|
||||
logger.info(line)
|
||||
elif args.type == 'structure':
|
||||
result = engine(img_path)
|
||||
save_structure_res(result, args.output, img_name)
|
||||
|
||||
for item in result:
|
||||
item.pop('img')
|
||||
logger.info(item)
|
||||
|
||||
|
|
|
@ -35,6 +35,7 @@ from ppocr.data.imaug import transform, create_operators
|
|||
from ppocr.data.simple_dataset import SimpleDataSet
|
||||
from ppocr.data.lmdb_dataset import LMDBDataSet
|
||||
from ppocr.data.pgnet_dataset import PGDataSet
|
||||
from ppocr.data.pubtab_dataset import PubTabDataSet
|
||||
|
||||
__all__ = ['build_dataloader', 'transform', 'create_operators']
|
||||
|
||||
|
@ -55,7 +56,7 @@ signal.signal(signal.SIGTERM, term_mp)
|
|||
def build_dataloader(config, mode, device, logger, seed=None):
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet']
|
||||
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet']
|
||||
module_name = config[mode]['dataset']['name']
|
||||
assert module_name in support_dict, Exception(
|
||||
'DataSet only support {}'.format(support_dict))
|
||||
|
|
|
@ -23,12 +23,14 @@ from .random_crop_data import EastRandomCropData, PSERandomCrop
|
|||
|
||||
from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, PILResize, CVResize
|
||||
from .randaugment import RandAugment
|
||||
from .copy_paste import CopyPaste
|
||||
from .operators import *
|
||||
from .label_ops import *
|
||||
|
||||
from .east_process import *
|
||||
from .sast_process import *
|
||||
from .pg_process import *
|
||||
from .gen_table_mask import *
|
||||
|
||||
|
||||
def transform(data, ops=None):
|
||||
|
|
|
@ -0,0 +1,166 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import cv2
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from shapely.geometry import Polygon
|
||||
|
||||
from ppocr.data.imaug.iaa_augment import IaaAugment
|
||||
from ppocr.data.imaug.random_crop_data import is_poly_outside_rect
|
||||
from tools.infer.utility import get_rotate_crop_image
|
||||
|
||||
|
||||
class CopyPaste(object):
|
||||
def __init__(self, objects_paste_ratio=0.2, limit_paste=True, **kwargs):
|
||||
self.ext_data_num = 1
|
||||
self.objects_paste_ratio = objects_paste_ratio
|
||||
self.limit_paste = limit_paste
|
||||
augmenter_args = [{'type': 'Resize', 'args': {'size': [0.5, 3]}}]
|
||||
self.aug = IaaAugment(augmenter_args)
|
||||
|
||||
def __call__(self, data):
|
||||
src_img = data['image']
|
||||
src_polys = data['polys'].tolist()
|
||||
src_ignores = data['ignore_tags'].tolist()
|
||||
ext_data = data['ext_data'][0]
|
||||
ext_image = ext_data['image']
|
||||
ext_polys = ext_data['polys']
|
||||
ext_ignores = ext_data['ignore_tags']
|
||||
|
||||
indexs = [i for i in range(len(ext_ignores)) if not ext_ignores[i]]
|
||||
select_num = max(
|
||||
1, min(int(self.objects_paste_ratio * len(ext_polys)), 30))
|
||||
|
||||
random.shuffle(indexs)
|
||||
select_idxs = indexs[:select_num]
|
||||
select_polys = ext_polys[select_idxs]
|
||||
select_ignores = ext_ignores[select_idxs]
|
||||
|
||||
src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
|
||||
ext_image = cv2.cvtColor(ext_image, cv2.COLOR_BGR2RGB)
|
||||
src_img = Image.fromarray(src_img).convert('RGBA')
|
||||
for poly, tag in zip(select_polys, select_ignores):
|
||||
box_img = get_rotate_crop_image(ext_image, poly)
|
||||
|
||||
src_img, box = self.paste_img(src_img, box_img, src_polys)
|
||||
if box is not None:
|
||||
src_polys.append(box)
|
||||
src_ignores.append(tag)
|
||||
src_img = cv2.cvtColor(np.array(src_img), cv2.COLOR_RGB2BGR)
|
||||
h, w = src_img.shape[:2]
|
||||
src_polys = np.array(src_polys)
|
||||
src_polys[:, :, 0] = np.clip(src_polys[:, :, 0], 0, w)
|
||||
src_polys[:, :, 1] = np.clip(src_polys[:, :, 1], 0, h)
|
||||
data['image'] = src_img
|
||||
data['polys'] = src_polys
|
||||
data['ignore_tags'] = np.array(src_ignores)
|
||||
return data
|
||||
|
||||
def paste_img(self, src_img, box_img, src_polys):
|
||||
box_img_pil = Image.fromarray(box_img).convert('RGBA')
|
||||
src_w, src_h = src_img.size
|
||||
box_w, box_h = box_img_pil.size
|
||||
|
||||
angle = np.random.randint(0, 360)
|
||||
box = np.array([[[0, 0], [box_w, 0], [box_w, box_h], [0, box_h]]])
|
||||
box = rotate_bbox(box_img, box, angle)[0]
|
||||
box_img_pil = box_img_pil.rotate(angle, expand=1)
|
||||
box_w, box_h = box_img_pil.width, box_img_pil.height
|
||||
if src_w - box_w < 0 or src_h - box_h < 0:
|
||||
return src_img, None
|
||||
|
||||
paste_x, paste_y = self.select_coord(src_polys, box, src_w - box_w,
|
||||
src_h - box_h)
|
||||
if paste_x is None:
|
||||
return src_img, None
|
||||
box[:, 0] += paste_x
|
||||
box[:, 1] += paste_y
|
||||
r, g, b, A = box_img_pil.split()
|
||||
src_img.paste(box_img_pil, (paste_x, paste_y), mask=A)
|
||||
|
||||
return src_img, box
|
||||
|
||||
def select_coord(self, src_polys, box, endx, endy):
|
||||
if self.limit_paste:
|
||||
xmin, ymin, xmax, ymax = box[:, 0].min(), box[:, 1].min(
|
||||
), box[:, 0].max(), box[:, 1].max()
|
||||
for _ in range(50):
|
||||
paste_x = random.randint(0, endx)
|
||||
paste_y = random.randint(0, endy)
|
||||
xmin1 = xmin + paste_x
|
||||
xmax1 = xmax + paste_x
|
||||
ymin1 = ymin + paste_y
|
||||
ymax1 = ymax + paste_y
|
||||
|
||||
num_poly_in_rect = 0
|
||||
for poly in src_polys:
|
||||
if not is_poly_outside_rect(poly, xmin1, ymin1,
|
||||
xmax1 - xmin1, ymax1 - ymin1):
|
||||
num_poly_in_rect += 1
|
||||
break
|
||||
if num_poly_in_rect == 0:
|
||||
return paste_x, paste_y
|
||||
return None, None
|
||||
else:
|
||||
paste_x = random.randint(0, endx)
|
||||
paste_y = random.randint(0, endy)
|
||||
return paste_x, paste_y
|
||||
|
||||
|
||||
def get_union(pD, pG):
|
||||
return Polygon(pD).union(Polygon(pG)).area
|
||||
|
||||
|
||||
def get_intersection_over_union(pD, pG):
|
||||
return get_intersection(pD, pG) / get_union(pD, pG)
|
||||
|
||||
|
||||
def get_intersection(pD, pG):
|
||||
return Polygon(pD).intersection(Polygon(pG)).area
|
||||
|
||||
|
||||
def rotate_bbox(img, text_polys, angle, scale=1):
|
||||
"""
|
||||
from https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/augment.py
|
||||
Args:
|
||||
img: np.ndarray
|
||||
text_polys: np.ndarray N*4*2
|
||||
angle: int
|
||||
scale: int
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
w = img.shape[1]
|
||||
h = img.shape[0]
|
||||
|
||||
rangle = np.deg2rad(angle)
|
||||
nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w))
|
||||
nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w))
|
||||
rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
|
||||
rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
|
||||
rot_mat[0, 2] += rot_move[0]
|
||||
rot_mat[1, 2] += rot_move[1]
|
||||
|
||||
# ---------------------- rotate box ----------------------
|
||||
rot_text_polys = list()
|
||||
for bbox in text_polys:
|
||||
point1 = np.dot(rot_mat, np.array([bbox[0, 0], bbox[0, 1], 1]))
|
||||
point2 = np.dot(rot_mat, np.array([bbox[1, 0], bbox[1, 1], 1]))
|
||||
point3 = np.dot(rot_mat, np.array([bbox[2, 0], bbox[2, 1], 1]))
|
||||
point4 = np.dot(rot_mat, np.array([bbox[3, 0], bbox[3, 1], 1]))
|
||||
rot_text_polys.append([point1, point2, point3, point4])
|
||||
return np.array(rot_text_polys, dtype=np.float32)
|
|
@ -0,0 +1,244 @@
|
|||
"""
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import sys
|
||||
import six
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
|
||||
class GenTableMask(object):
|
||||
""" gen table mask """
|
||||
|
||||
def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
|
||||
self.shrink_h_max = 5
|
||||
self.shrink_w_max = 5
|
||||
self.mask_type = mask_type
|
||||
|
||||
def projection(self, erosion, h, w, spilt_threshold=0):
|
||||
# 水平投影
|
||||
projection_map = np.ones_like(erosion)
|
||||
project_val_array = [0 for _ in range(0, h)]
|
||||
|
||||
for j in range(0, h):
|
||||
for i in range(0, w):
|
||||
if erosion[j, i] == 255:
|
||||
project_val_array[j] += 1
|
||||
# 根据数组,获取切割点
|
||||
start_idx = 0 # 记录进入字符区的索引
|
||||
end_idx = 0 # 记录进入空白区域的索引
|
||||
in_text = False # 是否遍历到了字符区内
|
||||
box_list = []
|
||||
for i in range(len(project_val_array)):
|
||||
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
|
||||
in_text = True
|
||||
start_idx = i
|
||||
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
|
||||
end_idx = i
|
||||
in_text = False
|
||||
if end_idx - start_idx <= 2:
|
||||
continue
|
||||
box_list.append((start_idx, end_idx + 1))
|
||||
|
||||
if in_text:
|
||||
box_list.append((start_idx, h - 1))
|
||||
# 绘制投影直方图
|
||||
for j in range(0, h):
|
||||
for i in range(0, project_val_array[j]):
|
||||
projection_map[j, i] = 0
|
||||
return box_list, projection_map
|
||||
|
||||
def projection_cx(self, box_img):
|
||||
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
|
||||
h, w = box_gray_img.shape
|
||||
# 灰度图片进行二值化处理
|
||||
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
|
||||
# 纵向腐蚀
|
||||
if h < w:
|
||||
kernel = np.ones((2, 1), np.uint8)
|
||||
erode = cv2.erode(thresh1, kernel, iterations=1)
|
||||
else:
|
||||
erode = thresh1
|
||||
# 水平膨胀
|
||||
kernel = np.ones((1, 5), np.uint8)
|
||||
erosion = cv2.dilate(erode, kernel, iterations=1)
|
||||
# 水平投影
|
||||
projection_map = np.ones_like(erosion)
|
||||
project_val_array = [0 for _ in range(0, h)]
|
||||
|
||||
for j in range(0, h):
|
||||
for i in range(0, w):
|
||||
if erosion[j, i] == 255:
|
||||
project_val_array[j] += 1
|
||||
# 根据数组,获取切割点
|
||||
start_idx = 0 # 记录进入字符区的索引
|
||||
end_idx = 0 # 记录进入空白区域的索引
|
||||
in_text = False # 是否遍历到了字符区内
|
||||
box_list = []
|
||||
spilt_threshold = 0
|
||||
for i in range(len(project_val_array)):
|
||||
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
|
||||
in_text = True
|
||||
start_idx = i
|
||||
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
|
||||
end_idx = i
|
||||
in_text = False
|
||||
if end_idx - start_idx <= 2:
|
||||
continue
|
||||
box_list.append((start_idx, end_idx + 1))
|
||||
|
||||
if in_text:
|
||||
box_list.append((start_idx, h - 1))
|
||||
# 绘制投影直方图
|
||||
for j in range(0, h):
|
||||
for i in range(0, project_val_array[j]):
|
||||
projection_map[j, i] = 0
|
||||
split_bbox_list = []
|
||||
if len(box_list) > 1:
|
||||
for i, (h_start, h_end) in enumerate(box_list):
|
||||
if i == 0:
|
||||
h_start = 0
|
||||
if i == len(box_list):
|
||||
h_end = h
|
||||
word_img = erosion[h_start:h_end + 1, :]
|
||||
word_h, word_w = word_img.shape
|
||||
w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h)
|
||||
w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
|
||||
if h_start > 0:
|
||||
h_start -= 1
|
||||
h_end += 1
|
||||
word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
|
||||
split_bbox_list.append([w_start, h_start, w_end, h_end])
|
||||
else:
|
||||
split_bbox_list.append([0, 0, w, h])
|
||||
return split_bbox_list
|
||||
|
||||
def shrink_bbox(self, bbox):
|
||||
left, top, right, bottom = bbox
|
||||
sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
|
||||
sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
|
||||
left_new = left + sh_w
|
||||
right_new = right - sh_w
|
||||
top_new = top + sh_h
|
||||
bottom_new = bottom - sh_h
|
||||
if left_new >= right_new:
|
||||
left_new = left
|
||||
right_new = right
|
||||
if top_new >= bottom_new:
|
||||
top_new = top
|
||||
bottom_new = bottom
|
||||
return [left_new, top_new, right_new, bottom_new]
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
cells = data['cells']
|
||||
height, width = img.shape[0:2]
|
||||
if self.mask_type == 1:
|
||||
mask_img = np.zeros((height, width), dtype=np.float32)
|
||||
else:
|
||||
mask_img = np.zeros((height, width, 3), dtype=np.float32)
|
||||
cell_num = len(cells)
|
||||
for cno in range(cell_num):
|
||||
if "bbox" in cells[cno]:
|
||||
bbox = cells[cno]['bbox']
|
||||
left, top, right, bottom = bbox
|
||||
box_img = img[top:bottom, left:right, :].copy()
|
||||
split_bbox_list = self.projection_cx(box_img)
|
||||
for sno in range(len(split_bbox_list)):
|
||||
split_bbox_list[sno][0] += left
|
||||
split_bbox_list[sno][1] += top
|
||||
split_bbox_list[sno][2] += left
|
||||
split_bbox_list[sno][3] += top
|
||||
|
||||
for sno in range(len(split_bbox_list)):
|
||||
left, top, right, bottom = split_bbox_list[sno]
|
||||
left, top, right, bottom = self.shrink_bbox([left, top, right, bottom])
|
||||
if self.mask_type == 1:
|
||||
mask_img[top:bottom, left:right] = 1.0
|
||||
data['mask_img'] = mask_img
|
||||
else:
|
||||
mask_img[top:bottom, left:right, :] = (255, 255, 255)
|
||||
data['image'] = mask_img
|
||||
return data
|
||||
|
||||
class ResizeTableImage(object):
|
||||
def __init__(self, max_len, **kwargs):
|
||||
super(ResizeTableImage, self).__init__()
|
||||
self.max_len = max_len
|
||||
|
||||
def get_img_bbox(self, cells):
|
||||
bbox_list = []
|
||||
if len(cells) == 0:
|
||||
return bbox_list
|
||||
cell_num = len(cells)
|
||||
for cno in range(cell_num):
|
||||
if "bbox" in cells[cno]:
|
||||
bbox = cells[cno]['bbox']
|
||||
bbox_list.append(bbox)
|
||||
return bbox_list
|
||||
|
||||
def resize_img_table(self, img, bbox_list, max_len):
|
||||
height, width = img.shape[0:2]
|
||||
ratio = max_len / (max(height, width) * 1.0)
|
||||
resize_h = int(height * ratio)
|
||||
resize_w = int(width * ratio)
|
||||
img_new = cv2.resize(img, (resize_w, resize_h))
|
||||
bbox_list_new = []
|
||||
for bno in range(len(bbox_list)):
|
||||
left, top, right, bottom = bbox_list[bno].copy()
|
||||
left = int(left * ratio)
|
||||
top = int(top * ratio)
|
||||
right = int(right * ratio)
|
||||
bottom = int(bottom * ratio)
|
||||
bbox_list_new.append([left, top, right, bottom])
|
||||
return img_new, bbox_list_new
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
if 'cells' not in data:
|
||||
cells = []
|
||||
else:
|
||||
cells = data['cells']
|
||||
bbox_list = self.get_img_bbox(cells)
|
||||
img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
|
||||
data['image'] = img_new
|
||||
cell_num = len(cells)
|
||||
bno = 0
|
||||
for cno in range(cell_num):
|
||||
if "bbox" in data['cells'][cno]:
|
||||
data['cells'][cno]['bbox'] = bbox_list_new[bno]
|
||||
bno += 1
|
||||
data['max_len'] = self.max_len
|
||||
return data
|
||||
|
||||
class PaddingTableImage(object):
|
||||
def __init__(self, **kwargs):
|
||||
super(PaddingTableImage, self).__init__()
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
max_len = data['max_len']
|
||||
padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
|
||||
height, width = img.shape[0:2]
|
||||
padding_img[0:height, 0:width, :] = img.copy()
|
||||
data['image'] = padding_img
|
||||
return data
|
||||
|
|
@ -19,6 +19,7 @@ from __future__ import unicode_literals
|
|||
|
||||
import numpy as np
|
||||
import string
|
||||
import json
|
||||
|
||||
|
||||
class ClsLabelEncode(object):
|
||||
|
@ -39,7 +40,6 @@ class DetLabelEncode(object):
|
|||
pass
|
||||
|
||||
def __call__(self, data):
|
||||
import json
|
||||
label = data['label']
|
||||
label = json.loads(label)
|
||||
nBox = len(label)
|
||||
|
@ -53,6 +53,8 @@ class DetLabelEncode(object):
|
|||
txt_tags.append(True)
|
||||
else:
|
||||
txt_tags.append(False)
|
||||
if len(boxes) == 0:
|
||||
return None
|
||||
boxes = self.expand_points_num(boxes)
|
||||
boxes = np.array(boxes, dtype=np.float32)
|
||||
txt_tags = np.array(txt_tags, dtype=np.bool)
|
||||
|
@ -379,3 +381,171 @@ class SRNLabelEncode(BaseRecLabelEncode):
|
|||
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
|
||||
% beg_or_end
|
||||
return idx
|
||||
|
||||
|
||||
class TableLabelEncode(object):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
max_elem_length,
|
||||
max_cell_num,
|
||||
character_dict_path,
|
||||
span_weight=1.0,
|
||||
**kwargs):
|
||||
self.max_text_length = max_text_length
|
||||
self.max_elem_length = max_elem_length
|
||||
self.max_cell_num = max_cell_num
|
||||
list_character, list_elem = self.load_char_elem_dict(
|
||||
character_dict_path)
|
||||
list_character = self.add_special_char(list_character)
|
||||
list_elem = self.add_special_char(list_elem)
|
||||
self.dict_character = {}
|
||||
for i, char in enumerate(list_character):
|
||||
self.dict_character[char] = i
|
||||
self.dict_elem = {}
|
||||
for i, elem in enumerate(list_elem):
|
||||
self.dict_elem[elem] = i
|
||||
self.span_weight = span_weight
|
||||
|
||||
def load_char_elem_dict(self, character_dict_path):
|
||||
list_character = []
|
||||
list_elem = []
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
substr = lines[0].decode('utf-8').strip("\r\n").split("\t")
|
||||
character_num = int(substr[0])
|
||||
elem_num = int(substr[1])
|
||||
|
||||
for cno in range(1, 1 + character_num):
|
||||
character = lines[cno].decode('utf-8').strip("\r\n")
|
||||
list_character.append(character)
|
||||
for eno in range(1 + character_num, 1 + character_num + elem_num):
|
||||
elem = lines[eno].decode('utf-8').strip("\r\n")
|
||||
list_elem.append(elem)
|
||||
return list_character, list_elem
|
||||
|
||||
def add_special_char(self, list_character):
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
list_character = [self.beg_str] + list_character + [self.end_str]
|
||||
return list_character
|
||||
|
||||
def get_span_idx_list(self):
|
||||
span_idx_list = []
|
||||
for elem in self.dict_elem:
|
||||
if 'span' in elem:
|
||||
span_idx_list.append(self.dict_elem[elem])
|
||||
return span_idx_list
|
||||
|
||||
def __call__(self, data):
|
||||
cells = data['cells']
|
||||
structure = data['structure']['tokens']
|
||||
structure = self.encode(structure, 'elem')
|
||||
if structure is None:
|
||||
return None
|
||||
elem_num = len(structure)
|
||||
structure = [0] + structure + [len(self.dict_elem) - 1]
|
||||
structure = structure + [0] * (self.max_elem_length + 2 - len(structure)
|
||||
)
|
||||
structure = np.array(structure)
|
||||
data['structure'] = structure
|
||||
elem_char_idx1 = self.dict_elem['<td>']
|
||||
elem_char_idx2 = self.dict_elem['<td']
|
||||
span_idx_list = self.get_span_idx_list()
|
||||
td_idx_list = np.logical_or(structure == elem_char_idx1,
|
||||
structure == elem_char_idx2)
|
||||
td_idx_list = np.where(td_idx_list)[0]
|
||||
|
||||
structure_mask = np.ones(
|
||||
(self.max_elem_length + 2, 1), dtype=np.float32)
|
||||
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
|
||||
bbox_list_mask = np.zeros(
|
||||
(self.max_elem_length + 2, 1), dtype=np.float32)
|
||||
img_height, img_width, img_ch = data['image'].shape
|
||||
if len(span_idx_list) > 0:
|
||||
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
|
||||
span_weight = min(max(span_weight, 1.0), self.span_weight)
|
||||
for cno in range(len(cells)):
|
||||
if 'bbox' in cells[cno]:
|
||||
bbox = cells[cno]['bbox'].copy()
|
||||
bbox[0] = bbox[0] * 1.0 / img_width
|
||||
bbox[1] = bbox[1] * 1.0 / img_height
|
||||
bbox[2] = bbox[2] * 1.0 / img_width
|
||||
bbox[3] = bbox[3] * 1.0 / img_height
|
||||
td_idx = td_idx_list[cno]
|
||||
bbox_list[td_idx] = bbox
|
||||
bbox_list_mask[td_idx] = 1.0
|
||||
cand_span_idx = td_idx + 1
|
||||
if cand_span_idx < (self.max_elem_length + 2):
|
||||
if structure[cand_span_idx] in span_idx_list:
|
||||
structure_mask[cand_span_idx] = span_weight
|
||||
|
||||
data['bbox_list'] = bbox_list
|
||||
data['bbox_list_mask'] = bbox_list_mask
|
||||
data['structure_mask'] = structure_mask
|
||||
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
|
||||
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
|
||||
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
|
||||
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
|
||||
data['sp_tokens'] = np.array([
|
||||
char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
|
||||
elem_char_idx1, elem_char_idx2, self.max_text_length,
|
||||
self.max_elem_length, self.max_cell_num, elem_num
|
||||
])
|
||||
return data
|
||||
|
||||
def encode(self, text, char_or_elem):
|
||||
"""convert text-label into text-index.
|
||||
"""
|
||||
if char_or_elem == "char":
|
||||
max_len = self.max_text_length
|
||||
current_dict = self.dict_character
|
||||
else:
|
||||
max_len = self.max_elem_length
|
||||
current_dict = self.dict_elem
|
||||
if len(text) > max_len:
|
||||
return None
|
||||
if len(text) == 0:
|
||||
if char_or_elem == "char":
|
||||
return [self.dict_character['space']]
|
||||
else:
|
||||
return None
|
||||
text_list = []
|
||||
for char in text:
|
||||
if char not in current_dict:
|
||||
return None
|
||||
text_list.append(current_dict[char])
|
||||
if len(text_list) == 0:
|
||||
if char_or_elem == "char":
|
||||
return [self.dict_character['space']]
|
||||
else:
|
||||
return None
|
||||
return text_list
|
||||
|
||||
def get_ignored_tokens(self, char_or_elem):
|
||||
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
|
||||
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
|
||||
return [beg_idx, end_idx]
|
||||
|
||||
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
|
||||
if char_or_elem == "char":
|
||||
if beg_or_end == "beg":
|
||||
idx = np.array(self.dict_character[self.beg_str])
|
||||
elif beg_or_end == "end":
|
||||
idx = np.array(self.dict_character[self.end_str])
|
||||
else:
|
||||
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
|
||||
% beg_or_end
|
||||
elif char_or_elem == "elem":
|
||||
if beg_or_end == "beg":
|
||||
idx = np.array(self.dict_elem[self.beg_str])
|
||||
elif beg_or_end == "end":
|
||||
idx = np.array(self.dict_elem[self.end_str])
|
||||
else:
|
||||
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
|
||||
% beg_or_end
|
||||
else:
|
||||
assert False, "Unsupport type %s in char_or_elem" \
|
||||
% char_or_elem
|
||||
return idx
|
||||
|
|
|
@ -195,7 +195,7 @@ class DetResizeForTest(object):
|
|||
img, (ratio_h, ratio_w)
|
||||
"""
|
||||
limit_side_len = self.limit_side_len
|
||||
h, w, _ = img.shape
|
||||
h, w, c = img.shape
|
||||
|
||||
# limit the max side
|
||||
if self.limit_type == 'max':
|
||||
|
@ -206,7 +206,7 @@ class DetResizeForTest(object):
|
|||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.
|
||||
else:
|
||||
elif self.limit_type == 'min':
|
||||
if min(h, w) < limit_side_len:
|
||||
if h < w:
|
||||
ratio = float(limit_side_len) / h
|
||||
|
@ -214,6 +214,10 @@ class DetResizeForTest(object):
|
|||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.
|
||||
elif self.limit_type == 'resize_long':
|
||||
ratio = float(limit_side_len) / max(h,w)
|
||||
else:
|
||||
raise Exception('not support limit type, image ')
|
||||
resize_h = int(h * ratio)
|
||||
resize_w = int(w * ratio)
|
||||
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
from paddle.io import Dataset
|
||||
import json
|
||||
|
||||
from .imaug import transform, create_operators
|
||||
|
||||
|
||||
class PubTabDataSet(Dataset):
|
||||
def __init__(self, config, mode, logger, seed=None):
|
||||
super(PubTabDataSet, self).__init__()
|
||||
self.logger = logger
|
||||
|
||||
global_config = config['Global']
|
||||
dataset_config = config[mode]['dataset']
|
||||
loader_config = config[mode]['loader']
|
||||
|
||||
label_file_path = dataset_config.pop('label_file_path')
|
||||
|
||||
self.data_dir = dataset_config['data_dir']
|
||||
self.do_shuffle = loader_config['shuffle']
|
||||
self.do_hard_select = False
|
||||
if 'hard_select' in loader_config:
|
||||
self.do_hard_select = loader_config['hard_select']
|
||||
self.hard_prob = loader_config['hard_prob']
|
||||
if self.do_hard_select:
|
||||
self.img_select_prob = self.load_hard_select_prob()
|
||||
self.table_select_type = None
|
||||
if 'table_select_type' in loader_config:
|
||||
self.table_select_type = loader_config['table_select_type']
|
||||
self.table_select_prob = loader_config['table_select_prob']
|
||||
|
||||
self.seed = seed
|
||||
logger.info("Initialize indexs of datasets:%s" % label_file_path)
|
||||
with open(label_file_path, "rb") as f:
|
||||
self.data_lines = f.readlines()
|
||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||
if mode.lower() == "train":
|
||||
self.shuffle_data_random()
|
||||
self.ops = create_operators(dataset_config['transforms'], global_config)
|
||||
|
||||
def shuffle_data_random(self):
|
||||
if self.do_shuffle:
|
||||
random.seed(self.seed)
|
||||
random.shuffle(self.data_lines)
|
||||
return
|
||||
|
||||
def __getitem__(self, idx):
|
||||
try:
|
||||
data_line = self.data_lines[idx]
|
||||
data_line = data_line.decode('utf-8').strip("\n")
|
||||
info = json.loads(data_line)
|
||||
file_name = info['filename']
|
||||
select_flag = True
|
||||
if self.do_hard_select:
|
||||
prob = self.img_select_prob[file_name]
|
||||
if prob < random.uniform(0, 1):
|
||||
select_flag = False
|
||||
|
||||
if self.table_select_type:
|
||||
structure = info['html']['structure']['tokens'].copy()
|
||||
structure_str = ''.join(structure)
|
||||
table_type = "simple"
|
||||
if 'colspan' in structure_str or 'rowspan' in structure_str:
|
||||
table_type = "complex"
|
||||
if table_type == "complex":
|
||||
if self.table_select_prob < random.uniform(0, 1):
|
||||
select_flag = False
|
||||
|
||||
if select_flag:
|
||||
cells = info['html']['cells'].copy()
|
||||
structure = info['html']['structure'].copy()
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
data = {'img_path': img_path, 'cells': cells, 'structure':structure}
|
||||
if not os.path.exists(img_path):
|
||||
raise Exception("{} does not exist!".format(img_path))
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
outs = transform(data, self.ops)
|
||||
else:
|
||||
outs = None
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"When parsing line {}, error happened with msg: {}".format(
|
||||
data_line, e))
|
||||
outs = None
|
||||
if outs is None:
|
||||
return self.__getitem__(np.random.randint(self.__len__()))
|
||||
return outs
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_idx_order_list)
|
|
@ -69,12 +69,42 @@ class SimpleDataSet(Dataset):
|
|||
random.shuffle(self.data_lines)
|
||||
return
|
||||
|
||||
def get_ext_data(self):
|
||||
ext_data_num = 0
|
||||
for op in self.ops:
|
||||
if hasattr(op, 'ext_data_num'):
|
||||
ext_data_num = getattr(op, 'ext_data_num')
|
||||
break
|
||||
load_data_ops = self.ops[:2]
|
||||
ext_data = []
|
||||
|
||||
while len(ext_data) < ext_data_num:
|
||||
file_idx = self.data_idx_order_list[np.random.randint(self.__len__(
|
||||
))]
|
||||
data_line = self.data_lines[file_idx]
|
||||
data_line = data_line.decode('utf-8')
|
||||
substr = data_line.strip("\n").split(self.delimiter)
|
||||
file_name = substr[0]
|
||||
label = substr[1]
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
data = {'img_path': img_path, 'label': label}
|
||||
if not os.path.exists(img_path):
|
||||
continue
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
data = transform(data, load_data_ops)
|
||||
if data is None:
|
||||
continue
|
||||
ext_data.append(data)
|
||||
return ext_data
|
||||
|
||||
def __getitem__(self, idx):
|
||||
file_idx = self.data_idx_order_list[idx]
|
||||
data_line = self.data_lines[file_idx]
|
||||
try:
|
||||
data_line = data_line.decode('utf-8')
|
||||
substr = data_line.strip("\n").strip("\r").split(self.delimiter)
|
||||
substr = data_line.strip("\n").split(self.delimiter)
|
||||
file_name = substr[0]
|
||||
label = substr[1]
|
||||
img_path = os.path.join(self.data_dir, file_name)
|
||||
|
@ -84,6 +114,7 @@ class SimpleDataSet(Dataset):
|
|||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
data['ext_data'] = self.get_ext_data()
|
||||
outs = transform(data, self.ops)
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
|
|
|
@ -38,11 +38,15 @@ from .basic_loss import DistanceLoss
|
|||
# combined loss function
|
||||
from .combined_loss import CombinedLoss
|
||||
|
||||
# table loss
|
||||
from .table_att_loss import TableAttentionLoss
|
||||
|
||||
def build_loss(config):
|
||||
support_dict = [
|
||||
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||
'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss']
|
||||
|
||||
'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss', 'TableAttentionLoss'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -54,6 +54,27 @@ class CELoss(nn.Layer):
|
|||
return loss
|
||||
|
||||
|
||||
class KLJSLoss(object):
|
||||
def __init__(self, mode='kl'):
|
||||
assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
|
||||
self.mode = mode
|
||||
|
||||
def __call__(self, p1, p2, reduction="mean"):
|
||||
|
||||
loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5))
|
||||
|
||||
if self.mode.lower() == "js":
|
||||
loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5))
|
||||
loss *= 0.5
|
||||
if reduction == "mean":
|
||||
loss = paddle.mean(loss, axis=[1,2])
|
||||
elif reduction=="none" or reduction is None:
|
||||
return loss
|
||||
else:
|
||||
loss = paddle.sum(loss, axis=[1,2])
|
||||
|
||||
return loss
|
||||
|
||||
class DMLLoss(nn.Layer):
|
||||
"""
|
||||
DMLLoss
|
||||
|
@ -70,16 +91,20 @@ class DMLLoss(nn.Layer):
|
|||
else:
|
||||
self.act = None
|
||||
|
||||
self.jskl_loss = KLJSLoss(mode="js")
|
||||
|
||||
def forward(self, out1, out2):
|
||||
if self.act is not None:
|
||||
out1 = self.act(out1)
|
||||
out2 = self.act(out2)
|
||||
|
||||
if len(out1.shape) < 2:
|
||||
log_out1 = paddle.log(out1)
|
||||
log_out2 = paddle.log(out2)
|
||||
loss = (F.kl_div(
|
||||
log_out1, out2, reduction='batchmean') + F.kl_div(
|
||||
log_out2, out1, reduction='batchmean')) / 2.0
|
||||
else:
|
||||
loss = self.jskl_loss(out1, out2)
|
||||
return loss
|
||||
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ import paddle.nn as nn
|
|||
|
||||
from .distillation_loss import DistillationCTCLoss
|
||||
from .distillation_loss import DistillationDMLLoss
|
||||
from .distillation_loss import DistillationDistanceLoss
|
||||
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
|
||||
|
||||
|
||||
class CombinedLoss(nn.Layer):
|
||||
|
@ -44,15 +44,16 @@ class CombinedLoss(nn.Layer):
|
|||
|
||||
def forward(self, input, batch, **kargs):
|
||||
loss_dict = {}
|
||||
loss_all = 0.
|
||||
for idx, loss_func in enumerate(self.loss_func):
|
||||
loss = loss_func(input, batch, **kargs)
|
||||
if isinstance(loss, paddle.Tensor):
|
||||
loss = {"loss_{}_{}".format(str(loss), idx): loss}
|
||||
weight = self.loss_weight[idx]
|
||||
loss = {
|
||||
"{}_{}".format(key, idx): loss[key] * weight
|
||||
for key in loss
|
||||
}
|
||||
loss_dict.update(loss)
|
||||
loss_dict["loss"] = paddle.add_n(list(loss_dict.values()))
|
||||
for key in loss.keys():
|
||||
if key == "loss":
|
||||
loss_all += loss[key] * weight
|
||||
else:
|
||||
loss_dict["{}_{}".format(key, idx)] = loss[key]
|
||||
loss_dict["loss"] = loss_all
|
||||
return loss_dict
|
||||
|
|
|
@ -14,23 +14,76 @@
|
|||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from .rec_ctc_loss import CTCLoss
|
||||
from .basic_loss import DMLLoss
|
||||
from .basic_loss import DistanceLoss
|
||||
from .det_db_loss import DBLoss
|
||||
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
|
||||
|
||||
|
||||
def _sum_loss(loss_dict):
|
||||
if "loss" in loss_dict.keys():
|
||||
return loss_dict
|
||||
else:
|
||||
loss_dict["loss"] = 0.
|
||||
for k, value in loss_dict.items():
|
||||
if k == "loss":
|
||||
continue
|
||||
else:
|
||||
loss_dict["loss"] += value
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDMLLoss(DMLLoss):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, model_name_pairs=[], act=None, key=None,
|
||||
name="loss_dml"):
|
||||
def __init__(self,
|
||||
model_name_pairs=[],
|
||||
act=None,
|
||||
key=None,
|
||||
maps_name=None,
|
||||
name="dml"):
|
||||
super().__init__(act=act)
|
||||
assert isinstance(model_name_pairs, list)
|
||||
self.key = key
|
||||
self.model_name_pairs = model_name_pairs
|
||||
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
|
||||
self.name = name
|
||||
self.maps_name = self._check_maps_name(maps_name)
|
||||
|
||||
def _check_model_name_pairs(self, model_name_pairs):
|
||||
if not isinstance(model_name_pairs, list):
|
||||
return []
|
||||
elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str):
|
||||
return model_name_pairs
|
||||
else:
|
||||
return [model_name_pairs]
|
||||
|
||||
def _check_maps_name(self, maps_name):
|
||||
if maps_name is None:
|
||||
return None
|
||||
elif type(maps_name) == str:
|
||||
return [maps_name]
|
||||
elif type(maps_name) == list:
|
||||
return [maps_name]
|
||||
else:
|
||||
return None
|
||||
|
||||
def _slice_out(self, outs):
|
||||
new_outs = {}
|
||||
for k in self.maps_name:
|
||||
if k == "thrink_maps":
|
||||
new_outs[k] = outs[:, 0, :, :]
|
||||
elif k == "threshold_maps":
|
||||
new_outs[k] = outs[:, 1, :, :]
|
||||
elif k == "binary_maps":
|
||||
new_outs[k] = outs[:, 2, :, :]
|
||||
else:
|
||||
continue
|
||||
return new_outs
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
|
@ -40,6 +93,8 @@ class DistillationDMLLoss(DMLLoss):
|
|||
if self.key is not None:
|
||||
out1 = out1[self.key]
|
||||
out2 = out2[self.key]
|
||||
|
||||
if self.maps_name is None:
|
||||
loss = super().forward(out1, out2)
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
|
@ -47,6 +102,21 @@ class DistillationDMLLoss(DMLLoss):
|
|||
idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, idx)] = loss
|
||||
else:
|
||||
outs1 = self._slice_out(out1)
|
||||
outs2 = self._slice_out(out2)
|
||||
for _c, k in enumerate(outs1.keys()):
|
||||
loss = super().forward(outs1[k], outs2[k])
|
||||
if isinstance(loss, dict):
|
||||
for key in loss:
|
||||
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
|
||||
0], pair[1], map_name, idx)] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c],
|
||||
idx)] = loss
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
|
||||
return loss_dict
|
||||
|
||||
|
||||
|
@ -73,6 +143,98 @@ class DistillationCTCLoss(CTCLoss):
|
|||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDBLoss(DBLoss):
|
||||
def __init__(self,
|
||||
model_name_list=[],
|
||||
balance_loss=True,
|
||||
main_loss_type='DiceLoss',
|
||||
alpha=5,
|
||||
beta=10,
|
||||
ohem_ratio=3,
|
||||
eps=1e-6,
|
||||
name="db",
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.model_name_list = model_name_list
|
||||
self.name = name
|
||||
self.key = None
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = {}
|
||||
for idx, model_name in enumerate(self.model_name_list):
|
||||
out = predicts[model_name]
|
||||
if self.key is not None:
|
||||
out = out[self.key]
|
||||
loss = super().forward(out, batch)
|
||||
|
||||
if isinstance(loss, dict):
|
||||
for key in loss.keys():
|
||||
if key == "loss":
|
||||
continue
|
||||
name = "{}_{}_{}".format(self.name, model_name, key)
|
||||
loss_dict[name] = loss[key]
|
||||
else:
|
||||
loss_dict["{}_{}".format(self.name, model_name)] = loss
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDilaDBLoss(DBLoss):
|
||||
def __init__(self,
|
||||
model_name_pairs=[],
|
||||
key=None,
|
||||
balance_loss=True,
|
||||
main_loss_type='DiceLoss',
|
||||
alpha=5,
|
||||
beta=10,
|
||||
ohem_ratio=3,
|
||||
eps=1e-6,
|
||||
name="dila_dbloss"):
|
||||
super().__init__()
|
||||
self.model_name_pairs = model_name_pairs
|
||||
self.name = name
|
||||
self.key = key
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
loss_dict = dict()
|
||||
for idx, pair in enumerate(self.model_name_pairs):
|
||||
stu_outs = predicts[pair[0]]
|
||||
tch_outs = predicts[pair[1]]
|
||||
if self.key is not None:
|
||||
stu_preds = stu_outs[self.key]
|
||||
tch_preds = tch_outs[self.key]
|
||||
|
||||
stu_shrink_maps = stu_preds[:, 0, :, :]
|
||||
stu_binary_maps = stu_preds[:, 2, :, :]
|
||||
|
||||
# dilation to teacher prediction
|
||||
dilation_w = np.array([[1, 1], [1, 1]])
|
||||
th_shrink_maps = tch_preds[:, 0, :, :]
|
||||
th_shrink_maps = th_shrink_maps.numpy() > 0.3 # thresh = 0.3
|
||||
dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
|
||||
for i in range(th_shrink_maps.shape[0]):
|
||||
dilate_maps[i] = cv2.dilate(
|
||||
th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
|
||||
th_shrink_maps = paddle.to_tensor(dilate_maps)
|
||||
|
||||
label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[
|
||||
1:]
|
||||
|
||||
# calculate the shrink map loss
|
||||
bce_loss = self.alpha * self.bce_loss(
|
||||
stu_shrink_maps, th_shrink_maps, label_shrink_mask)
|
||||
loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
|
||||
label_shrink_mask)
|
||||
|
||||
# k = f"{self.name}_{pair[0]}_{pair[1]}"
|
||||
k = "{}_{}_{}".format(self.name, pair[0], pair[1])
|
||||
loss_dict[k] = bce_loss + loss_binary_maps
|
||||
|
||||
loss_dict = _sum_loss(loss_dict)
|
||||
return loss_dict
|
||||
|
||||
|
||||
class DistillationDistanceLoss(DistanceLoss):
|
||||
"""
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
from paddle.nn import functional as F
|
||||
from paddle import fluid
|
||||
|
||||
class TableAttentionLoss(nn.Layer):
|
||||
def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
|
||||
super(TableAttentionLoss, self).__init__()
|
||||
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
|
||||
self.structure_weight = structure_weight
|
||||
self.loc_weight = loc_weight
|
||||
self.use_giou = use_giou
|
||||
self.giou_weight = giou_weight
|
||||
|
||||
def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
|
||||
'''
|
||||
:param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
|
||||
:param bbox:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
|
||||
:return: loss
|
||||
'''
|
||||
ix1 = fluid.layers.elementwise_max(preds[:, 0], bbox[:, 0])
|
||||
iy1 = fluid.layers.elementwise_max(preds[:, 1], bbox[:, 1])
|
||||
ix2 = fluid.layers.elementwise_min(preds[:, 2], bbox[:, 2])
|
||||
iy2 = fluid.layers.elementwise_min(preds[:, 3], bbox[:, 3])
|
||||
|
||||
iw = fluid.layers.clip(ix2 - ix1 + 1e-3, 0., 1e10)
|
||||
ih = fluid.layers.clip(iy2 - iy1 + 1e-3, 0., 1e10)
|
||||
|
||||
# overlap
|
||||
inters = iw * ih
|
||||
|
||||
# union
|
||||
uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3
|
||||
) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * (
|
||||
bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps
|
||||
|
||||
# ious
|
||||
ious = inters / uni
|
||||
|
||||
ex1 = fluid.layers.elementwise_min(preds[:, 0], bbox[:, 0])
|
||||
ey1 = fluid.layers.elementwise_min(preds[:, 1], bbox[:, 1])
|
||||
ex2 = fluid.layers.elementwise_max(preds[:, 2], bbox[:, 2])
|
||||
ey2 = fluid.layers.elementwise_max(preds[:, 3], bbox[:, 3])
|
||||
ew = fluid.layers.clip(ex2 - ex1 + 1e-3, 0., 1e10)
|
||||
eh = fluid.layers.clip(ey2 - ey1 + 1e-3, 0., 1e10)
|
||||
|
||||
# enclose erea
|
||||
enclose = ew * eh + eps
|
||||
giou = ious - (enclose - uni) / enclose
|
||||
|
||||
loss = 1 - giou
|
||||
|
||||
if reduction == 'mean':
|
||||
loss = paddle.mean(loss)
|
||||
elif reduction == 'sum':
|
||||
loss = paddle.sum(loss)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return loss
|
||||
|
||||
def forward(self, predicts, batch):
|
||||
structure_probs = predicts['structure_probs']
|
||||
structure_targets = batch[1].astype("int64")
|
||||
structure_targets = structure_targets[:, 1:]
|
||||
if len(batch) == 6:
|
||||
structure_mask = batch[5].astype("int64")
|
||||
structure_mask = structure_mask[:, 1:]
|
||||
structure_mask = paddle.reshape(structure_mask, [-1])
|
||||
structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
|
||||
structure_targets = paddle.reshape(structure_targets, [-1])
|
||||
structure_loss = self.loss_func(structure_probs, structure_targets)
|
||||
|
||||
if len(batch) == 6:
|
||||
structure_loss = structure_loss * structure_mask
|
||||
|
||||
# structure_loss = paddle.sum(structure_loss) * self.structure_weight
|
||||
structure_loss = paddle.mean(structure_loss) * self.structure_weight
|
||||
|
||||
loc_preds = predicts['loc_preds']
|
||||
loc_targets = batch[2].astype("float32")
|
||||
loc_targets_mask = batch[4].astype("float32")
|
||||
loc_targets = loc_targets[:, 1:, :]
|
||||
loc_targets_mask = loc_targets_mask[:, 1:, :]
|
||||
loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight
|
||||
if self.use_giou:
|
||||
loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight
|
||||
total_loss = structure_loss + loc_loss + loc_loss_giou
|
||||
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou}
|
||||
else:
|
||||
total_loss = structure_loss + loc_loss
|
||||
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss}
|
|
@ -26,11 +26,11 @@ from .rec_metric import RecMetric
|
|||
from .cls_metric import ClsMetric
|
||||
from .e2e_metric import E2EMetric
|
||||
from .distillation_metric import DistillationMetric
|
||||
|
||||
from .table_metric import TableMetric
|
||||
|
||||
def build_metric(config):
|
||||
support_dict = [
|
||||
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric"
|
||||
"DetMetric", "RecMetric", "ClsMetric", "E2EMetric", "DistillationMetric", "TableMetric"
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -55,6 +55,7 @@ class DetMetric(object):
|
|||
result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
|
||||
self.results.append(result)
|
||||
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
return metrics {
|
||||
|
|
|
@ -24,8 +24,8 @@ from .cls_metric import ClsMetric
|
|||
class DistillationMetric(object):
|
||||
def __init__(self,
|
||||
key=None,
|
||||
base_metric_name="RecMetric",
|
||||
main_indicator='acc',
|
||||
base_metric_name=None,
|
||||
main_indicator=None,
|
||||
**kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.key = key
|
||||
|
@ -42,16 +42,13 @@ class DistillationMetric(object):
|
|||
main_indicator=self.main_indicator, **self.kwargs)
|
||||
self.metrics[key].reset()
|
||||
|
||||
def __call__(self, preds, *args, **kwargs):
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
assert isinstance(preds, dict)
|
||||
if self.metrics is None:
|
||||
self._init_metrcis(preds)
|
||||
output = dict()
|
||||
for key in preds:
|
||||
metric = self.metrics[key].__call__(preds[key], *args, **kwargs)
|
||||
for sub_key in metric:
|
||||
output["{}_{}".format(key, sub_key)] = metric[sub_key]
|
||||
return output
|
||||
self.metrics[key].__call__(preds[key], batch, **kwargs)
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
class TableMetric(object):
|
||||
def __init__(self, main_indicator='acc', **kwargs):
|
||||
self.main_indicator = main_indicator
|
||||
self.reset()
|
||||
|
||||
def __call__(self, pred, batch, *args, **kwargs):
|
||||
structure_probs = pred['structure_probs'].numpy()
|
||||
structure_labels = batch[1]
|
||||
correct_num = 0
|
||||
all_num = 0
|
||||
structure_probs = np.argmax(structure_probs, axis=2)
|
||||
structure_labels = structure_labels[:, 1:]
|
||||
batch_size = structure_probs.shape[0]
|
||||
for bno in range(batch_size):
|
||||
all_num += 1
|
||||
if (structure_probs[bno] == structure_labels[bno]).all():
|
||||
correct_num += 1
|
||||
self.correct_num += correct_num
|
||||
self.all_num += all_num
|
||||
return {
|
||||
'acc': correct_num * 1.0 / all_num,
|
||||
}
|
||||
|
||||
def get_metric(self):
|
||||
"""
|
||||
return metrics {
|
||||
'acc': 0,
|
||||
}
|
||||
"""
|
||||
acc = 1.0 * self.correct_num / self.all_num
|
||||
self.reset()
|
||||
return {'acc': acc}
|
||||
|
||||
def reset(self):
|
||||
self.correct_num = 0
|
||||
self.all_num = 0
|
|
@ -1,4 +1,4 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
|
@ -77,10 +77,10 @@ class BaseModel(nn.Layer):
|
|||
if self.use_neck:
|
||||
x = self.neck(x)
|
||||
y["neck_out"] = x
|
||||
if data is None:
|
||||
x = self.head(x)
|
||||
x = self.head(x, targets=data)
|
||||
if isinstance(x, dict):
|
||||
y.update(x)
|
||||
else:
|
||||
x = self.head(x, data)
|
||||
y["head_out"] = x
|
||||
if self.return_all_feats:
|
||||
return y
|
||||
|
|
|
@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
|
|||
from ppocr.modeling.necks import build_neck
|
||||
from ppocr.modeling.heads import build_head
|
||||
from .base_model import BaseModel
|
||||
from ppocr.utils.save_load import init_model
|
||||
from ppocr.utils.save_load import init_model, load_pretrained_params
|
||||
|
||||
__all__ = ['DistillationModel']
|
||||
|
||||
|
@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
|
|||
pretrained = model_config.pop("pretrained")
|
||||
model = BaseModel(model_config)
|
||||
if pretrained is not None:
|
||||
init_model(model, path=pretrained)
|
||||
load_pretrained_params(model, pretrained)
|
||||
if freeze_params:
|
||||
for param in model.parameters():
|
||||
param.trainable = False
|
||||
|
|
|
@ -12,32 +12,36 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = ['build_backbone']
|
||||
__all__ = ["build_backbone"]
|
||||
|
||||
|
||||
def build_backbone(config, model_type):
|
||||
if model_type == 'det':
|
||||
if model_type == "det":
|
||||
from .det_mobilenet_v3 import MobileNetV3
|
||||
from .det_resnet_vd import ResNet
|
||||
from .det_resnet_vd_sast import ResNet_SAST
|
||||
support_dict = ['MobileNetV3', 'ResNet', 'ResNet_SAST']
|
||||
elif model_type == 'rec' or model_type == 'cls':
|
||||
support_dict = ["MobileNetV3", "ResNet", "ResNet_SAST"]
|
||||
elif model_type == "rec" or model_type == "cls":
|
||||
from .rec_mobilenet_v3 import MobileNetV3
|
||||
from .rec_resnet_vd import ResNet
|
||||
from .rec_resnet_fpn import ResNetFPN
|
||||
from .rec_mv1_enhance import MobileNetV1Enhance
|
||||
from .rec_nrtr_mtb import MTB
|
||||
from .rec_swin import SwinTransformer
|
||||
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'SwinTransformer']
|
||||
|
||||
elif model_type == 'e2e':
|
||||
support_dict = ['MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB', 'SwinTransformer']
|
||||
elif model_type == "e2e":
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
support_dict = ['ResNet']
|
||||
support_dict = ["ResNet"]
|
||||
elif model_type == "table":
|
||||
from .table_resnet_vd import ResNet
|
||||
from .table_mobilenet_v3 import MobileNetV3
|
||||
support_dict = ["ResNet", "MobileNetV3"]
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
module_name = config.pop('name')
|
||||
module_name = config.pop("name")
|
||||
assert module_name in support_dict, Exception(
|
||||
'when model typs is {}, backbone only support {}'.format(model_type,
|
||||
"when model typs is {}, backbone only support {}".format(model_type,
|
||||
support_dict))
|
||||
module_class = eval(module_name)(**config)
|
||||
return module_class
|
||||
|
|
|
@ -0,0 +1,256 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
|
||||
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
|
||||
from paddle.nn.initializer import KaimingNormal
|
||||
import math
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import ParamAttr, reshape, transpose, concat, split
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
|
||||
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
|
||||
from paddle.nn.initializer import KaimingNormal
|
||||
import math
|
||||
from paddle.nn.functional import hardswish, hardsigmoid
|
||||
from paddle.regularizer import L2Decay
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
filter_size,
|
||||
num_filters,
|
||||
stride,
|
||||
padding,
|
||||
channels=None,
|
||||
num_groups=1,
|
||||
act='hard_swish'):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
|
||||
self._conv = Conv2D(
|
||||
in_channels=num_channels,
|
||||
out_channels=num_filters,
|
||||
kernel_size=filter_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=num_groups,
|
||||
weight_attr=ParamAttr(initializer=KaimingNormal()),
|
||||
bias_attr=False)
|
||||
|
||||
self._batch_norm = BatchNorm(
|
||||
num_filters,
|
||||
act=act,
|
||||
param_attr=ParamAttr(regularizer=L2Decay(0.0)),
|
||||
bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
return y
|
||||
|
||||
|
||||
class DepthwiseSeparable(nn.Layer):
|
||||
def __init__(self,
|
||||
num_channels,
|
||||
num_filters1,
|
||||
num_filters2,
|
||||
num_groups,
|
||||
stride,
|
||||
scale,
|
||||
dw_size=3,
|
||||
padding=1,
|
||||
use_se=False):
|
||||
super(DepthwiseSeparable, self).__init__()
|
||||
self.use_se = use_se
|
||||
self._depthwise_conv = ConvBNLayer(
|
||||
num_channels=num_channels,
|
||||
num_filters=int(num_filters1 * scale),
|
||||
filter_size=dw_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
num_groups=int(num_groups * scale))
|
||||
if use_se:
|
||||
self._se = SEModule(int(num_filters1 * scale))
|
||||
self._pointwise_conv = ConvBNLayer(
|
||||
num_channels=int(num_filters1 * scale),
|
||||
filter_size=1,
|
||||
num_filters=int(num_filters2 * scale),
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._depthwise_conv(inputs)
|
||||
if self.use_se:
|
||||
y = self._se(y)
|
||||
y = self._pointwise_conv(y)
|
||||
return y
|
||||
|
||||
|
||||
class MobileNetV1Enhance(nn.Layer):
|
||||
def __init__(self, in_channels=3, scale=0.5, **kwargs):
|
||||
super().__init__()
|
||||
self.scale = scale
|
||||
self.block_list = []
|
||||
|
||||
self.conv1 = ConvBNLayer(
|
||||
num_channels=3,
|
||||
filter_size=3,
|
||||
channels=3,
|
||||
num_filters=int(32 * scale),
|
||||
stride=2,
|
||||
padding=1)
|
||||
|
||||
conv2_1 = DepthwiseSeparable(
|
||||
num_channels=int(32 * scale),
|
||||
num_filters1=32,
|
||||
num_filters2=64,
|
||||
num_groups=32,
|
||||
stride=1,
|
||||
scale=scale)
|
||||
self.block_list.append(conv2_1)
|
||||
|
||||
conv2_2 = DepthwiseSeparable(
|
||||
num_channels=int(64 * scale),
|
||||
num_filters1=64,
|
||||
num_filters2=128,
|
||||
num_groups=64,
|
||||
stride=1,
|
||||
scale=scale)
|
||||
self.block_list.append(conv2_2)
|
||||
|
||||
conv3_1 = DepthwiseSeparable(
|
||||
num_channels=int(128 * scale),
|
||||
num_filters1=128,
|
||||
num_filters2=128,
|
||||
num_groups=128,
|
||||
stride=1,
|
||||
scale=scale)
|
||||
self.block_list.append(conv3_1)
|
||||
|
||||
conv3_2 = DepthwiseSeparable(
|
||||
num_channels=int(128 * scale),
|
||||
num_filters1=128,
|
||||
num_filters2=256,
|
||||
num_groups=128,
|
||||
stride=(2, 1),
|
||||
scale=scale)
|
||||
self.block_list.append(conv3_2)
|
||||
|
||||
conv4_1 = DepthwiseSeparable(
|
||||
num_channels=int(256 * scale),
|
||||
num_filters1=256,
|
||||
num_filters2=256,
|
||||
num_groups=256,
|
||||
stride=1,
|
||||
scale=scale)
|
||||
self.block_list.append(conv4_1)
|
||||
|
||||
conv4_2 = DepthwiseSeparable(
|
||||
num_channels=int(256 * scale),
|
||||
num_filters1=256,
|
||||
num_filters2=512,
|
||||
num_groups=256,
|
||||
stride=(2, 1),
|
||||
scale=scale)
|
||||
self.block_list.append(conv4_2)
|
||||
|
||||
for _ in range(5):
|
||||
conv5 = DepthwiseSeparable(
|
||||
num_channels=int(512 * scale),
|
||||
num_filters1=512,
|
||||
num_filters2=512,
|
||||
num_groups=512,
|
||||
stride=1,
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
scale=scale,
|
||||
use_se=False)
|
||||
self.block_list.append(conv5)
|
||||
|
||||
conv5_6 = DepthwiseSeparable(
|
||||
num_channels=int(512 * scale),
|
||||
num_filters1=512,
|
||||
num_filters2=1024,
|
||||
num_groups=512,
|
||||
stride=(2, 1),
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
scale=scale,
|
||||
use_se=True)
|
||||
self.block_list.append(conv5_6)
|
||||
|
||||
conv6 = DepthwiseSeparable(
|
||||
num_channels=int(1024 * scale),
|
||||
num_filters1=1024,
|
||||
num_filters2=1024,
|
||||
num_groups=1024,
|
||||
stride=1,
|
||||
dw_size=5,
|
||||
padding=2,
|
||||
use_se=True,
|
||||
scale=scale)
|
||||
self.block_list.append(conv6)
|
||||
|
||||
self.block_list = nn.Sequential(*self.block_list)
|
||||
|
||||
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
|
||||
self.out_channels = int(1024 * scale)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv1(inputs)
|
||||
y = self.block_list(y)
|
||||
y = self.pool(y)
|
||||
return y
|
||||
|
||||
|
||||
class SEModule(nn.Layer):
|
||||
def __init__(self, channel, reduction=4):
|
||||
super(SEModule, self).__init__()
|
||||
self.avg_pool = AdaptiveAvgPool2D(1)
|
||||
self.conv1 = Conv2D(
|
||||
in_channels=channel,
|
||||
out_channels=channel // reduction,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
weight_attr=ParamAttr(),
|
||||
bias_attr=ParamAttr())
|
||||
self.conv2 = Conv2D(
|
||||
in_channels=channel // reduction,
|
||||
out_channels=channel,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
weight_attr=ParamAttr(),
|
||||
bias_attr=ParamAttr())
|
||||
|
||||
def forward(self, inputs):
|
||||
outputs = self.avg_pool(inputs)
|
||||
outputs = self.conv1(outputs)
|
||||
outputs = F.relu(outputs)
|
||||
outputs = self.conv2(outputs)
|
||||
outputs = hardsigmoid(outputs)
|
||||
return paddle.multiply(x=inputs, y=outputs)
|
|
@ -0,0 +1,287 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
|
||||
__all__ = ['MobileNetV3']
|
||||
|
||||
|
||||
def make_divisible(v, divisor=8, min_value=None):
|
||||
if min_value is None:
|
||||
min_value = divisor
|
||||
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
|
||||
if new_v < 0.9 * v:
|
||||
new_v += divisor
|
||||
return new_v
|
||||
|
||||
|
||||
class MobileNetV3(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
model_name='large',
|
||||
scale=0.5,
|
||||
disable_se=False,
|
||||
**kwargs):
|
||||
"""
|
||||
the MobilenetV3 backbone network for detection module.
|
||||
Args:
|
||||
params(dict): the super parameters for build network
|
||||
"""
|
||||
super(MobileNetV3, self).__init__()
|
||||
|
||||
self.disable_se = disable_se
|
||||
|
||||
if model_name == "large":
|
||||
cfg = [
|
||||
# k, exp, c, se, nl, s,
|
||||
[3, 16, 16, False, 'relu', 1],
|
||||
[3, 64, 24, False, 'relu', 2],
|
||||
[3, 72, 24, False, 'relu', 1],
|
||||
[5, 72, 40, True, 'relu', 2],
|
||||
[5, 120, 40, True, 'relu', 1],
|
||||
[5, 120, 40, True, 'relu', 1],
|
||||
[3, 240, 80, False, 'hardswish', 2],
|
||||
[3, 200, 80, False, 'hardswish', 1],
|
||||
[3, 184, 80, False, 'hardswish', 1],
|
||||
[3, 184, 80, False, 'hardswish', 1],
|
||||
[3, 480, 112, True, 'hardswish', 1],
|
||||
[3, 672, 112, True, 'hardswish', 1],
|
||||
[5, 672, 160, True, 'hardswish', 2],
|
||||
[5, 960, 160, True, 'hardswish', 1],
|
||||
[5, 960, 160, True, 'hardswish', 1],
|
||||
]
|
||||
cls_ch_squeeze = 960
|
||||
elif model_name == "small":
|
||||
cfg = [
|
||||
# k, exp, c, se, nl, s,
|
||||
[3, 16, 16, True, 'relu', 2],
|
||||
[3, 72, 24, False, 'relu', 2],
|
||||
[3, 88, 24, False, 'relu', 1],
|
||||
[5, 96, 40, True, 'hardswish', 2],
|
||||
[5, 240, 40, True, 'hardswish', 1],
|
||||
[5, 240, 40, True, 'hardswish', 1],
|
||||
[5, 120, 48, True, 'hardswish', 1],
|
||||
[5, 144, 48, True, 'hardswish', 1],
|
||||
[5, 288, 96, True, 'hardswish', 2],
|
||||
[5, 576, 96, True, 'hardswish', 1],
|
||||
[5, 576, 96, True, 'hardswish', 1],
|
||||
]
|
||||
cls_ch_squeeze = 576
|
||||
else:
|
||||
raise NotImplementedError("mode[" + model_name +
|
||||
"_model] is not implemented!")
|
||||
|
||||
supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
|
||||
assert scale in supported_scale, \
|
||||
"supported scale are {} but input scale is {}".format(supported_scale, scale)
|
||||
inplanes = 16
|
||||
# conv1
|
||||
self.conv = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=make_divisible(inplanes * scale),
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act='hardswish',
|
||||
name='conv1')
|
||||
|
||||
self.stages = []
|
||||
self.out_channels = []
|
||||
block_list = []
|
||||
i = 0
|
||||
inplanes = make_divisible(inplanes * scale)
|
||||
for (k, exp, c, se, nl, s) in cfg:
|
||||
se = se and not self.disable_se
|
||||
start_idx = 2 if model_name == 'large' else 0
|
||||
if s == 2 and i > start_idx:
|
||||
self.out_channels.append(inplanes)
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
block_list = []
|
||||
block_list.append(
|
||||
ResidualUnit(
|
||||
in_channels=inplanes,
|
||||
mid_channels=make_divisible(scale * exp),
|
||||
out_channels=make_divisible(scale * c),
|
||||
kernel_size=k,
|
||||
stride=s,
|
||||
use_se=se,
|
||||
act=nl,
|
||||
name="conv" + str(i + 2)))
|
||||
inplanes = make_divisible(scale * c)
|
||||
i += 1
|
||||
block_list.append(
|
||||
ConvBNLayer(
|
||||
in_channels=inplanes,
|
||||
out_channels=make_divisible(scale * cls_ch_squeeze),
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act='hardswish',
|
||||
name='conv_last'))
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
|
||||
for i, stage in enumerate(self.stages):
|
||||
self.add_sublayer(sublayer=stage, name="stage{}".format(i))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
out_list = []
|
||||
for stage in self.stages:
|
||||
x = stage(x)
|
||||
out_list.append(x)
|
||||
return out_list
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
self.if_act = if_act
|
||||
self.act = act
|
||||
self.conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + '_weights'),
|
||||
bias_attr=False)
|
||||
|
||||
self.bn = nn.BatchNorm(
|
||||
num_channels=out_channels,
|
||||
act=None,
|
||||
param_attr=ParamAttr(name=name + "_bn_scale"),
|
||||
bias_attr=ParamAttr(name=name + "_bn_offset"),
|
||||
moving_mean_name=name + "_bn_mean",
|
||||
moving_variance_name=name + "_bn_variance")
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
if self.if_act:
|
||||
if self.act == "relu":
|
||||
x = F.relu(x)
|
||||
elif self.act == "hardswish":
|
||||
x = F.hardswish(x)
|
||||
else:
|
||||
print("The activation function({}) is selected incorrectly.".
|
||||
format(self.act))
|
||||
exit()
|
||||
return x
|
||||
|
||||
|
||||
class ResidualUnit(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
mid_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
use_se,
|
||||
act=None,
|
||||
name=''):
|
||||
super(ResidualUnit, self).__init__()
|
||||
self.if_shortcut = stride == 1 and in_channels == out_channels
|
||||
self.if_se = use_se
|
||||
|
||||
self.expand_conv = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=mid_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
if_act=True,
|
||||
act=act,
|
||||
name=name + "_expand")
|
||||
self.bottleneck_conv = ConvBNLayer(
|
||||
in_channels=mid_channels,
|
||||
out_channels=mid_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=int((kernel_size - 1) // 2),
|
||||
groups=mid_channels,
|
||||
if_act=True,
|
||||
act=act,
|
||||
name=name + "_depthwise")
|
||||
if self.if_se:
|
||||
self.mid_se = SEModule(mid_channels, name=name + "_se")
|
||||
self.linear_conv = ConvBNLayer(
|
||||
in_channels=mid_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
if_act=False,
|
||||
act=None,
|
||||
name=name + "_linear")
|
||||
|
||||
def forward(self, inputs):
|
||||
x = self.expand_conv(inputs)
|
||||
x = self.bottleneck_conv(x)
|
||||
if self.if_se:
|
||||
x = self.mid_se(x)
|
||||
x = self.linear_conv(x)
|
||||
if self.if_shortcut:
|
||||
x = paddle.add(inputs, x)
|
||||
return x
|
||||
|
||||
|
||||
class SEModule(nn.Layer):
|
||||
def __init__(self, in_channels, reduction=4, name=""):
|
||||
super(SEModule, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2D(1)
|
||||
self.conv1 = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels // reduction,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
weight_attr=ParamAttr(name=name + "_1_weights"),
|
||||
bias_attr=ParamAttr(name=name + "_1_offset"))
|
||||
self.conv2 = nn.Conv2D(
|
||||
in_channels=in_channels // reduction,
|
||||
out_channels=in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
weight_attr=ParamAttr(name + "_2_weights"),
|
||||
bias_attr=ParamAttr(name=name + "_2_offset"))
|
||||
|
||||
def forward(self, inputs):
|
||||
outputs = self.avg_pool(inputs)
|
||||
outputs = self.conv1(outputs)
|
||||
outputs = F.relu(outputs)
|
||||
outputs = self.conv2(outputs)
|
||||
outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
|
||||
return inputs * outputs
|
|
@ -0,0 +1,280 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
__all__ = ["ResNet"]
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
is_vd_mode=False,
|
||||
act=None,
|
||||
name=None, ):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
|
||||
self.is_vd_mode = is_vd_mode
|
||||
self._pool2d_avg = nn.AvgPool2D(
|
||||
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||
self._conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=False)
|
||||
if name == "conv1":
|
||||
bn_name = "bn_" + name
|
||||
else:
|
||||
bn_name = "bn" + name[3:]
|
||||
self._batch_norm = nn.BatchNorm(
|
||||
out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name=bn_name + '_scale'),
|
||||
bias_attr=ParamAttr(bn_name + '_offset'),
|
||||
moving_mean_name=bn_name + '_mean',
|
||||
moving_variance_name=bn_name + '_variance')
|
||||
|
||||
def forward(self, inputs):
|
||||
if self.is_vd_mode:
|
||||
inputs = self._pool2d_avg(inputs)
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
return y
|
||||
|
||||
|
||||
class BottleneckBlock(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
shortcut=True,
|
||||
if_first=False,
|
||||
name=None):
|
||||
super(BottleneckBlock, self).__init__()
|
||||
|
||||
self.conv0 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
act='relu',
|
||||
name=name + "_branch2a")
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
act='relu',
|
||||
name=name + "_branch2b")
|
||||
self.conv2 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels * 4,
|
||||
kernel_size=1,
|
||||
act=None,
|
||||
name=name + "_branch2c")
|
||||
|
||||
if not shortcut:
|
||||
self.short = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels * 4,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
is_vd_mode=False if if_first else True,
|
||||
name=name + "_branch1")
|
||||
|
||||
self.shortcut = shortcut
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv0(inputs)
|
||||
conv1 = self.conv1(y)
|
||||
conv2 = self.conv2(conv1)
|
||||
|
||||
if self.shortcut:
|
||||
short = inputs
|
||||
else:
|
||||
short = self.short(inputs)
|
||||
y = paddle.add(x=short, y=conv2)
|
||||
y = F.relu(y)
|
||||
return y
|
||||
|
||||
|
||||
class BasicBlock(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
shortcut=True,
|
||||
if_first=False,
|
||||
name=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.stride = stride
|
||||
self.conv0 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
act='relu',
|
||||
name=name + "_branch2a")
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
act=None,
|
||||
name=name + "_branch2b")
|
||||
|
||||
if not shortcut:
|
||||
self.short = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
is_vd_mode=False if if_first else True,
|
||||
name=name + "_branch1")
|
||||
|
||||
self.shortcut = shortcut
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv0(inputs)
|
||||
conv1 = self.conv1(y)
|
||||
|
||||
if self.shortcut:
|
||||
short = inputs
|
||||
else:
|
||||
short = self.short(inputs)
|
||||
y = paddle.add(x=short, y=conv1)
|
||||
y = F.relu(y)
|
||||
return y
|
||||
|
||||
|
||||
class ResNet(nn.Layer):
|
||||
def __init__(self, in_channels=3, layers=50, **kwargs):
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
self.layers = layers
|
||||
supported_layers = [18, 34, 50, 101, 152, 200]
|
||||
assert layers in supported_layers, \
|
||||
"supported layers are {} but input layer is {}".format(
|
||||
supported_layers, layers)
|
||||
|
||||
if layers == 18:
|
||||
depth = [2, 2, 2, 2]
|
||||
elif layers == 34 or layers == 50:
|
||||
depth = [3, 4, 6, 3]
|
||||
elif layers == 101:
|
||||
depth = [3, 4, 23, 3]
|
||||
elif layers == 152:
|
||||
depth = [3, 8, 36, 3]
|
||||
elif layers == 200:
|
||||
depth = [3, 12, 48, 3]
|
||||
num_channels = [64, 256, 512,
|
||||
1024] if layers >= 50 else [64, 64, 128, 256]
|
||||
num_filters = [64, 128, 256, 512]
|
||||
|
||||
self.conv1_1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=32,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
act='relu',
|
||||
name="conv1_1")
|
||||
self.conv1_2 = ConvBNLayer(
|
||||
in_channels=32,
|
||||
out_channels=32,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name="conv1_2")
|
||||
self.conv1_3 = ConvBNLayer(
|
||||
in_channels=32,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name="conv1_3")
|
||||
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.stages = []
|
||||
self.out_channels = []
|
||||
if layers >= 50:
|
||||
for block in range(len(depth)):
|
||||
block_list = []
|
||||
shortcut = False
|
||||
for i in range(depth[block]):
|
||||
if layers in [101, 152] and block == 2:
|
||||
if i == 0:
|
||||
conv_name = "res" + str(block + 2) + "a"
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + "b" + str(i)
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
bottleneck_block = self.add_sublayer(
|
||||
'bb_%d_%d' % (block, i),
|
||||
BottleneckBlock(
|
||||
in_channels=num_channels[block]
|
||||
if i == 0 else num_filters[block] * 4,
|
||||
out_channels=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
if_first=block == i == 0,
|
||||
name=conv_name))
|
||||
shortcut = True
|
||||
block_list.append(bottleneck_block)
|
||||
self.out_channels.append(num_filters[block] * 4)
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
else:
|
||||
for block in range(len(depth)):
|
||||
block_list = []
|
||||
shortcut = False
|
||||
for i in range(depth[block]):
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
basic_block = self.add_sublayer(
|
||||
'bb_%d_%d' % (block, i),
|
||||
BasicBlock(
|
||||
in_channels=num_channels[block]
|
||||
if i == 0 else num_filters[block],
|
||||
out_channels=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
if_first=block == i == 0,
|
||||
name=conv_name))
|
||||
shortcut = True
|
||||
block_list.append(basic_block)
|
||||
self.out_channels.append(num_filters[block])
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv1_1(inputs)
|
||||
y = self.conv1_2(y)
|
||||
y = self.conv1_3(y)
|
||||
y = self.pool2d_max(y)
|
||||
out = []
|
||||
for block in self.stages:
|
||||
y = block(y)
|
||||
out.append(y)
|
||||
return out
|
|
@ -32,8 +32,12 @@ def build_head(config):
|
|||
from .cls_head import ClsHead
|
||||
support_dict = [
|
||||
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
||||
'SRNHead', 'PGHead', 'TransformerOptim']
|
||||
|
||||
'SRNHead', 'PGHead', 'TransformerOptim', 'TableAttentionHead']
|
||||
|
||||
|
||||
#table head
|
||||
from .table_att_head import TableAttentionHead
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('head only support {}'.format(
|
||||
|
|
|
@ -43,7 +43,7 @@ class ClsHead(nn.Layer):
|
|||
initializer=nn.initializer.Uniform(-stdv, stdv)),
|
||||
bias_attr=ParamAttr(name="fc_0.b_0"), )
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, targets=None):
|
||||
x = self.pool(x)
|
||||
x = paddle.reshape(x, shape=[x.shape[0], x.shape[1]])
|
||||
x = self.fc(x)
|
||||
|
|
|
@ -106,7 +106,7 @@ class DBHead(nn.Layer):
|
|||
def step_function(self, x, y):
|
||||
return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, targets=None):
|
||||
shrink_maps = self.binarize(x)
|
||||
if not self.training:
|
||||
return {'maps': shrink_maps}
|
||||
|
|
|
@ -109,7 +109,7 @@ class EASTHead(nn.Layer):
|
|||
act=None,
|
||||
name="f_geo")
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, targets=None):
|
||||
f_det = self.det_conv1(x)
|
||||
f_det = self.det_conv2(f_det)
|
||||
f_score = self.score_conv(f_det)
|
||||
|
|
|
@ -116,7 +116,7 @@ class SASTHead(nn.Layer):
|
|||
self.head1 = SAST_Header1(in_channels)
|
||||
self.head2 = SAST_Header2(in_channels)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, targets=None):
|
||||
f_score, f_border = self.head1(x)
|
||||
f_tvo, f_tco = self.head2(x)
|
||||
|
||||
|
|
|
@ -220,7 +220,7 @@ class PGHead(nn.Layer):
|
|||
weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
|
||||
bias_attr=False)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, targets=None):
|
||||
f_score = self.conv_f_score1(x)
|
||||
f_score = self.conv_f_score2(f_score)
|
||||
f_score = self.conv_f_score3(f_score)
|
||||
|
|
|
@ -33,8 +33,14 @@ def get_para_bias_attr(l2_decay, k):
|
|||
|
||||
|
||||
class CTCHead(nn.Layer):
|
||||
def __init__(self, in_channels, out_channels, fc_decay=0.0004, **kwargs):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
fc_decay=0.0004,
|
||||
mid_channels=None,
|
||||
**kwargs):
|
||||
super(CTCHead, self).__init__()
|
||||
if mid_channels is None:
|
||||
weight_attr, bias_attr = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=in_channels)
|
||||
self.fc = nn.Linear(
|
||||
|
@ -42,10 +48,32 @@ class CTCHead(nn.Layer):
|
|||
out_channels,
|
||||
weight_attr=weight_attr,
|
||||
bias_attr=bias_attr)
|
||||
self.out_channels = out_channels
|
||||
else:
|
||||
weight_attr1, bias_attr1 = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=in_channels)
|
||||
self.fc1 = nn.Linear(
|
||||
in_channels,
|
||||
mid_channels,
|
||||
weight_attr=weight_attr1,
|
||||
bias_attr=bias_attr1)
|
||||
|
||||
def forward(self, x, labels=None):
|
||||
weight_attr2, bias_attr2 = get_para_bias_attr(
|
||||
l2_decay=fc_decay, k=mid_channels)
|
||||
self.fc2 = nn.Linear(
|
||||
mid_channels,
|
||||
out_channels,
|
||||
weight_attr=weight_attr2,
|
||||
bias_attr=bias_attr2)
|
||||
self.out_channels = out_channels
|
||||
self.mid_channels = mid_channels
|
||||
|
||||
def forward(self, x, targets=None):
|
||||
if self.mid_channels is None:
|
||||
predicts = self.fc(x)
|
||||
else:
|
||||
predicts = self.fc1(x)
|
||||
predicts = self.fc2(predicts)
|
||||
|
||||
if not self.training:
|
||||
predicts = F.softmax(predicts, axis=2)
|
||||
return predicts
|
||||
|
|
|
@ -250,7 +250,8 @@ class SRNHead(nn.Layer):
|
|||
|
||||
self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
|
||||
|
||||
def forward(self, inputs, others):
|
||||
def forward(self, inputs, targets=None):
|
||||
others = targets[-4:]
|
||||
encoder_word_pos = others[0]
|
||||
gsrm_word_pos = others[1]
|
||||
gsrm_slf_attn_bias1 = others[2]
|
||||
|
|
|
@ -0,0 +1,238 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
class TableAttentionHead(nn.Layer):
|
||||
def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
|
||||
super(TableAttentionHead, self).__init__()
|
||||
self.input_size = in_channels[-1]
|
||||
self.hidden_size = hidden_size
|
||||
self.elem_num = 30
|
||||
self.max_text_length = 100
|
||||
self.max_elem_length = 500
|
||||
self.max_cell_num = 500
|
||||
|
||||
self.structure_attention_cell = AttentionGRUCell(
|
||||
self.input_size, hidden_size, self.elem_num, use_gru=False)
|
||||
self.structure_generator = nn.Linear(hidden_size, self.elem_num)
|
||||
self.loc_type = loc_type
|
||||
self.in_max_len = in_max_len
|
||||
|
||||
if self.loc_type == 1:
|
||||
self.loc_generator = nn.Linear(hidden_size, 4)
|
||||
else:
|
||||
if self.in_max_len == 640:
|
||||
self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
|
||||
elif self.in_max_len == 800:
|
||||
self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
|
||||
else:
|
||||
self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
|
||||
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
|
||||
|
||||
def _char_to_onehot(self, input_char, onehot_dim):
|
||||
input_ont_hot = F.one_hot(input_char, onehot_dim)
|
||||
return input_ont_hot
|
||||
|
||||
def forward(self, inputs, targets=None):
|
||||
# if and else branch are both needed when you want to assign a variable
|
||||
# if you modify the var in just one branch, then the modification will not work.
|
||||
fea = inputs[-1]
|
||||
if len(fea.shape) == 3:
|
||||
pass
|
||||
else:
|
||||
last_shape = int(np.prod(fea.shape[2:])) # gry added
|
||||
fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
|
||||
fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
|
||||
batch_size = fea.shape[0]
|
||||
|
||||
hidden = paddle.zeros((batch_size, self.hidden_size))
|
||||
output_hiddens = []
|
||||
if self.training and targets is not None:
|
||||
structure = targets[0]
|
||||
for i in range(self.max_elem_length+1):
|
||||
elem_onehots = self._char_to_onehot(
|
||||
structure[:, i], onehot_dim=self.elem_num)
|
||||
(outputs, hidden), alpha = self.structure_attention_cell(
|
||||
hidden, fea, elem_onehots)
|
||||
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
|
||||
output = paddle.concat(output_hiddens, axis=1)
|
||||
structure_probs = self.structure_generator(output)
|
||||
if self.loc_type == 1:
|
||||
loc_preds = self.loc_generator(output)
|
||||
loc_preds = F.sigmoid(loc_preds)
|
||||
else:
|
||||
loc_fea = fea.transpose([0, 2, 1])
|
||||
loc_fea = self.loc_fea_trans(loc_fea)
|
||||
loc_fea = loc_fea.transpose([0, 2, 1])
|
||||
loc_concat = paddle.concat([output, loc_fea], axis=2)
|
||||
loc_preds = self.loc_generator(loc_concat)
|
||||
loc_preds = F.sigmoid(loc_preds)
|
||||
else:
|
||||
temp_elem = paddle.zeros(shape=[batch_size], dtype="int32")
|
||||
structure_probs = None
|
||||
loc_preds = None
|
||||
elem_onehots = None
|
||||
outputs = None
|
||||
alpha = None
|
||||
max_elem_length = paddle.to_tensor(self.max_elem_length)
|
||||
i = 0
|
||||
while i < max_elem_length+1:
|
||||
elem_onehots = self._char_to_onehot(
|
||||
temp_elem, onehot_dim=self.elem_num)
|
||||
(outputs, hidden), alpha = self.structure_attention_cell(
|
||||
hidden, fea, elem_onehots)
|
||||
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
|
||||
structure_probs_step = self.structure_generator(outputs)
|
||||
temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
|
||||
i += 1
|
||||
|
||||
output = paddle.concat(output_hiddens, axis=1)
|
||||
structure_probs = self.structure_generator(output)
|
||||
structure_probs = F.softmax(structure_probs)
|
||||
if self.loc_type == 1:
|
||||
loc_preds = self.loc_generator(output)
|
||||
loc_preds = F.sigmoid(loc_preds)
|
||||
else:
|
||||
loc_fea = fea.transpose([0, 2, 1])
|
||||
loc_fea = self.loc_fea_trans(loc_fea)
|
||||
loc_fea = loc_fea.transpose([0, 2, 1])
|
||||
loc_concat = paddle.concat([output, loc_fea], axis=2)
|
||||
loc_preds = self.loc_generator(loc_concat)
|
||||
loc_preds = F.sigmoid(loc_preds)
|
||||
return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
|
||||
|
||||
|
||||
class AttentionGRUCell(nn.Layer):
|
||||
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
|
||||
super(AttentionGRUCell, self).__init__()
|
||||
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
|
||||
self.h2h = nn.Linear(hidden_size, hidden_size)
|
||||
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
|
||||
self.rnn = nn.GRUCell(
|
||||
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
def forward(self, prev_hidden, batch_H, char_onehots):
|
||||
batch_H_proj = self.i2h(batch_H)
|
||||
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
|
||||
res = paddle.add(batch_H_proj, prev_hidden_proj)
|
||||
res = paddle.tanh(res)
|
||||
e = self.score(res)
|
||||
alpha = F.softmax(e, axis=1)
|
||||
alpha = paddle.transpose(alpha, [0, 2, 1])
|
||||
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
|
||||
concat_context = paddle.concat([context, char_onehots], 1)
|
||||
cur_hidden = self.rnn(concat_context, prev_hidden)
|
||||
return cur_hidden, alpha
|
||||
|
||||
|
||||
class AttentionLSTM(nn.Layer):
|
||||
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
|
||||
super(AttentionLSTM, self).__init__()
|
||||
self.input_size = in_channels
|
||||
self.hidden_size = hidden_size
|
||||
self.num_classes = out_channels
|
||||
|
||||
self.attention_cell = AttentionLSTMCell(
|
||||
in_channels, hidden_size, out_channels, use_gru=False)
|
||||
self.generator = nn.Linear(hidden_size, out_channels)
|
||||
|
||||
def _char_to_onehot(self, input_char, onehot_dim):
|
||||
input_ont_hot = F.one_hot(input_char, onehot_dim)
|
||||
return input_ont_hot
|
||||
|
||||
def forward(self, inputs, targets=None, batch_max_length=25):
|
||||
batch_size = inputs.shape[0]
|
||||
num_steps = batch_max_length
|
||||
|
||||
hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
|
||||
(batch_size, self.hidden_size)))
|
||||
output_hiddens = []
|
||||
|
||||
if targets is not None:
|
||||
for i in range(num_steps):
|
||||
# one-hot vectors for a i-th char
|
||||
char_onehots = self._char_to_onehot(
|
||||
targets[:, i], onehot_dim=self.num_classes)
|
||||
hidden, alpha = self.attention_cell(hidden, inputs,
|
||||
char_onehots)
|
||||
|
||||
hidden = (hidden[1][0], hidden[1][1])
|
||||
output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
|
||||
output = paddle.concat(output_hiddens, axis=1)
|
||||
probs = self.generator(output)
|
||||
|
||||
else:
|
||||
targets = paddle.zeros(shape=[batch_size], dtype="int32")
|
||||
probs = None
|
||||
|
||||
for i in range(num_steps):
|
||||
char_onehots = self._char_to_onehot(
|
||||
targets, onehot_dim=self.num_classes)
|
||||
hidden, alpha = self.attention_cell(hidden, inputs,
|
||||
char_onehots)
|
||||
probs_step = self.generator(hidden[0])
|
||||
hidden = (hidden[1][0], hidden[1][1])
|
||||
if probs is None:
|
||||
probs = paddle.unsqueeze(probs_step, axis=1)
|
||||
else:
|
||||
probs = paddle.concat(
|
||||
[probs, paddle.unsqueeze(
|
||||
probs_step, axis=1)], axis=1)
|
||||
|
||||
next_input = probs_step.argmax(axis=1)
|
||||
|
||||
targets = next_input
|
||||
|
||||
return probs
|
||||
|
||||
|
||||
class AttentionLSTMCell(nn.Layer):
|
||||
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
|
||||
super(AttentionLSTMCell, self).__init__()
|
||||
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
|
||||
self.h2h = nn.Linear(hidden_size, hidden_size)
|
||||
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
|
||||
if not use_gru:
|
||||
self.rnn = nn.LSTMCell(
|
||||
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||
else:
|
||||
self.rnn = nn.GRUCell(
|
||||
input_size=input_size + num_embeddings, hidden_size=hidden_size)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
def forward(self, prev_hidden, batch_H, char_onehots):
|
||||
batch_H_proj = self.i2h(batch_H)
|
||||
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
|
||||
res = paddle.add(batch_H_proj, prev_hidden_proj)
|
||||
res = paddle.tanh(res)
|
||||
e = self.score(res)
|
||||
|
||||
alpha = F.softmax(e, axis=1)
|
||||
alpha = paddle.transpose(alpha, [0, 2, 1])
|
||||
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
|
||||
concat_context = paddle.concat([context, char_onehots], 1)
|
||||
cur_hidden = self.rnn(concat_context, prev_hidden)
|
||||
|
||||
return cur_hidden, alpha
|
|
@ -21,7 +21,8 @@ def build_neck(config):
|
|||
from .sast_fpn import SASTFPN
|
||||
from .rnn import SequenceEncoder
|
||||
from .pg_fpn import PGFPN
|
||||
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
|
||||
from .table_fpn import TableFPN
|
||||
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('neck only support {}'.format(
|
||||
|
|