Merge branch 'android_demo' of https://github.com/WenmuZhou/PaddleOCR into android_demo
|
@ -206,7 +206,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
self.labelList = EditInList()
|
||||
labelListContainer = QWidget()
|
||||
labelListContainer.setLayout(listLayout)
|
||||
self.labelList.itemActivated.connect(self.labelSelectionChanged)
|
||||
#self.labelList.itemActivated.connect(self.labelSelectionChanged)
|
||||
self.labelList.itemSelectionChanged.connect(self.labelSelectionChanged)
|
||||
self.labelList.clicked.connect(self.labelList.item_clicked)
|
||||
# Connect to itemChanged to detect checkbox changes.
|
||||
|
@ -219,7 +219,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
################## detection box ####################
|
||||
self.BoxList = QListWidget()
|
||||
|
||||
self.BoxList.itemActivated.connect(self.boxSelectionChanged)
|
||||
#self.BoxList.itemActivated.connect(self.boxSelectionChanged)
|
||||
self.BoxList.itemSelectionChanged.connect(self.boxSelectionChanged)
|
||||
self.BoxList.itemDoubleClicked.connect(self.editBox)
|
||||
# Connect to itemChanged to detect checkbox changes.
|
||||
|
@ -435,7 +435,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
|
||||
######## New actions #######
|
||||
AutoRec = action(getStr('autoRecognition'), self.autoRecognition,
|
||||
'Ctrl+Shift+A', 'Auto', getStr('autoRecognition'), enabled=False)
|
||||
'', 'Auto', getStr('autoRecognition'), enabled=False)
|
||||
|
||||
reRec = action(getStr('reRecognition'), self.reRecognition,
|
||||
'Ctrl+Shift+R', 'reRec', getStr('reRecognition'), enabled=False)
|
||||
|
@ -444,7 +444,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
'Ctrl+R', 'reRec', getStr('singleRe'), enabled=False)
|
||||
|
||||
createpoly = action(getStr('creatPolygon'), self.createPolygon,
|
||||
'q', 'new', 'Creat Polygon', enabled=True)
|
||||
'q', 'new', getStr('creatPolygon'), enabled=True)
|
||||
|
||||
saveRec = action(getStr('saveRec'), self.saveRecResult,
|
||||
'', 'save', getStr('saveRec'), enabled=False)
|
||||
|
@ -452,6 +452,12 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
saveLabel = action(getStr('saveLabel'), self.saveLabelFile, #
|
||||
'Ctrl+S', 'save', getStr('saveLabel'), enabled=False)
|
||||
|
||||
undoLastPoint = action(getStr("undoLastPoint"), self.canvas.undoLastPoint,
|
||||
'Ctrl+Z', "undo", getStr("undoLastPoint"), enabled=False)
|
||||
|
||||
undo = action(getStr("undo"), self.undoShapeEdit,
|
||||
'Ctrl+Z', "undo", getStr("undo"), enabled=False)
|
||||
|
||||
self.editButton.setDefaultAction(edit)
|
||||
self.newButton.setDefaultAction(create)
|
||||
self.DelButton.setDefaultAction(deleteImg)
|
||||
|
@ -512,10 +518,11 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
zoom=zoom, zoomIn=zoomIn, zoomOut=zoomOut, zoomOrg=zoomOrg,
|
||||
fitWindow=fitWindow, fitWidth=fitWidth,
|
||||
zoomActions=zoomActions, saveLabel=saveLabel,
|
||||
undo=undo, undoLastPoint=undoLastPoint,
|
||||
fileMenuActions=(
|
||||
opendir, saveLabel, resetAll, quit),
|
||||
beginner=(), advanced=(),
|
||||
editMenu=(createpoly, edit, copy, delete,singleRere,
|
||||
editMenu=(createpoly, edit, copy, delete,singleRere,None, undo, undoLastPoint,
|
||||
None, color1, self.drawSquaresOption),
|
||||
beginnerContext=(create, edit, copy, delete, singleRere),
|
||||
advancedContext=(createMode, editMode, edit, copy,
|
||||
|
@ -549,8 +556,13 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
self.labelDialogOption.setChecked(settings.get(SETTING_PAINT_LABEL, False))
|
||||
self.labelDialogOption.triggered.connect(self.speedChoose)
|
||||
|
||||
self.autoSaveOption = QAction(getStr('autoSaveMode'), self)
|
||||
self.autoSaveOption.setCheckable(True)
|
||||
self.autoSaveOption.setChecked(settings.get(SETTING_PAINT_LABEL, False))
|
||||
self.autoSaveOption.triggered.connect(self.autoSaveFunc)
|
||||
|
||||
addActions(self.menus.file,
|
||||
(opendir, None, saveLabel, saveRec, None, resetAll, deleteImg, quit))
|
||||
(opendir, None, saveLabel, saveRec, self.autoSaveOption, None, resetAll, deleteImg, quit))
|
||||
|
||||
addActions(self.menus.help, (showSteps, showInfo))
|
||||
addActions(self.menus.view, (
|
||||
|
@ -566,9 +578,9 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
|
||||
# Custom context menu for the canvas widget:
|
||||
addActions(self.canvas.menus[0], self.actions.beginnerContext)
|
||||
addActions(self.canvas.menus[1], (
|
||||
action('&Copy here', self.copyShape),
|
||||
action('&Move here', self.moveShape)))
|
||||
#addActions(self.canvas.menus[1], (
|
||||
# action('&Copy here', self.copyShape),
|
||||
# action('&Move here', self.moveShape)))
|
||||
|
||||
|
||||
self.statusBar().showMessage('%s started.' % __appname__)
|
||||
|
@ -758,6 +770,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
self.canvas.setEditing(False)
|
||||
self.canvas.fourpoint = True
|
||||
self.actions.create.setEnabled(False)
|
||||
self.actions.undoLastPoint.setEnabled(True)
|
||||
|
||||
def toggleDrawingSensitive(self, drawing=True):
|
||||
"""In the middle of drawing, toggling between modes should be disabled."""
|
||||
|
@ -866,10 +879,11 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
self.updateComboBox()
|
||||
|
||||
def updateBoxlist(self):
|
||||
shape = self.canvas.selectedShape
|
||||
item = self.shapesToItemsbox[shape] # listitem
|
||||
text = [(int(p.x()), int(p.y())) for p in shape.points]
|
||||
item.setText(str(text))
|
||||
for shape in self.canvas.selectedShapes+[self.canvas.hShape]:
|
||||
item = self.shapesToItemsbox[shape] # listitem
|
||||
text = [(int(p.x()), int(p.y())) for p in shape.points]
|
||||
item.setText(str(text))
|
||||
self.actions.undo.setEnabled(True)
|
||||
self.setDirty()
|
||||
|
||||
def indexTo5Files(self, currIndex):
|
||||
|
@ -902,23 +916,27 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
if len(self.mImgList) > 0:
|
||||
self.zoomWidget.setValue(self.zoomWidgetValue + self.imgsplider.value())
|
||||
|
||||
# React to canvas signals.
|
||||
def shapeSelectionChanged(self, selected=False):
|
||||
if self._noSelectionSlot:
|
||||
self._noSelectionSlot = False
|
||||
else:
|
||||
shape = self.canvas.selectedShape
|
||||
if shape:
|
||||
self.shapesToItems[shape].setSelected(True)
|
||||
self.shapesToItemsbox[shape].setSelected(True) # ADD
|
||||
else:
|
||||
self.labelList.clearSelection()
|
||||
self.actions.delete.setEnabled(selected)
|
||||
self.actions.copy.setEnabled(selected)
|
||||
self.actions.edit.setEnabled(selected)
|
||||
self.actions.shapeLineColor.setEnabled(selected)
|
||||
self.actions.shapeFillColor.setEnabled(selected)
|
||||
self.actions.singleRere.setEnabled(selected)
|
||||
|
||||
def shapeSelectionChanged(self, selected_shapes):
|
||||
self._noSelectionSlot = True
|
||||
for shape in self.canvas.selectedShapes:
|
||||
shape.selected = False
|
||||
self.labelList.clearSelection()
|
||||
self.canvas.selectedShapes = selected_shapes
|
||||
for shape in self.canvas.selectedShapes:
|
||||
shape.selected = True
|
||||
self.shapesToItems[shape].setSelected(True)
|
||||
self.shapesToItemsbox[shape].setSelected(True)
|
||||
|
||||
self.labelList.scrollToItem(self.currentItem()) # QAbstractItemView.EnsureVisible
|
||||
self.BoxList.scrollToItem(self.currentBox())
|
||||
|
||||
self._noSelectionSlot = False
|
||||
n_selected = len(selected_shapes)
|
||||
self.actions.singleRere.setEnabled(n_selected)
|
||||
self.actions.delete.setEnabled(n_selected)
|
||||
self.actions.copy.setEnabled(n_selected)
|
||||
self.actions.edit.setEnabled(n_selected == 1)
|
||||
|
||||
def addLabel(self, shape):
|
||||
shape.paintLabel = self.displayLabelOption.isChecked()
|
||||
|
@ -941,22 +959,23 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
action.setEnabled(True)
|
||||
self.updateComboBox()
|
||||
|
||||
def remLabel(self, shape):
|
||||
if shape is None:
|
||||
def remLabels(self, shapes):
|
||||
if shapes is None:
|
||||
# print('rm empty label')
|
||||
return
|
||||
item = self.shapesToItems[shape]
|
||||
self.labelList.takeItem(self.labelList.row(item))
|
||||
del self.shapesToItems[shape]
|
||||
del self.itemsToShapes[item]
|
||||
self.updateComboBox()
|
||||
for shape in shapes:
|
||||
item = self.shapesToItems[shape]
|
||||
self.labelList.takeItem(self.labelList.row(item))
|
||||
del self.shapesToItems[shape]
|
||||
del self.itemsToShapes[item]
|
||||
self.updateComboBox()
|
||||
|
||||
# ADD:
|
||||
item = self.shapesToItemsbox[shape]
|
||||
self.BoxList.takeItem(self.BoxList.row(item))
|
||||
del self.shapesToItemsbox[shape]
|
||||
del self.itemsToShapesbox[item]
|
||||
self.updateComboBox()
|
||||
# ADD:
|
||||
item = self.shapesToItemsbox[shape]
|
||||
self.BoxList.takeItem(self.BoxList.row(item))
|
||||
del self.shapesToItemsbox[shape]
|
||||
del self.itemsToShapesbox[item]
|
||||
self.updateComboBox()
|
||||
|
||||
def loadLabels(self, shapes):
|
||||
s = []
|
||||
|
@ -1001,7 +1020,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
item.setText(str([(int(p.x()), int(p.y())) for p in shape.points]))
|
||||
self.updateComboBox()
|
||||
|
||||
def updateComboBox(self):
|
||||
def updateComboBox(self): # TODO:貌似没用
|
||||
# Get the unique labels and add them to the Combobox.
|
||||
itemsTextList = [str(self.labelList.item(i).text()) for i in range(self.labelList.count())]
|
||||
|
||||
|
@ -1054,26 +1073,38 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
return False
|
||||
|
||||
def copySelectedShape(self):
|
||||
self.addLabel(self.canvas.copySelectedShape())
|
||||
for shape in self.canvas.copySelectedShape():
|
||||
self.addLabel(shape)
|
||||
# fix copy and delete
|
||||
self.shapeSelectionChanged(True)
|
||||
#self.shapeSelectionChanged(True)
|
||||
|
||||
|
||||
def labelSelectionChanged(self):
|
||||
item = self.currentItem()
|
||||
self.labelList.scrollToItem(item, QAbstractItemView.EnsureVisible)
|
||||
if item and self.canvas.editing():
|
||||
self._noSelectionSlot = True
|
||||
self.canvas.selectShape(self.itemsToShapes[item])
|
||||
shape = self.itemsToShapes[item]
|
||||
if self._noSelectionSlot:
|
||||
return
|
||||
if self.canvas.editing():
|
||||
selected_shapes = []
|
||||
for item in self.labelList.selectedItems():
|
||||
selected_shapes.append(self.itemsToShapes[item])
|
||||
if selected_shapes:
|
||||
self.canvas.selectShapes(selected_shapes)
|
||||
else:
|
||||
self.canvas.deSelectShape()
|
||||
|
||||
|
||||
def boxSelectionChanged(self):
|
||||
item = self.currentBox()
|
||||
self.BoxList.scrollToItem(item, QAbstractItemView.EnsureVisible)
|
||||
if item and self.canvas.editing():
|
||||
self._noSelectionSlot = True
|
||||
self.canvas.selectShape(self.itemsToShapesbox[item])
|
||||
shape = self.itemsToShapesbox[item]
|
||||
if self._noSelectionSlot:
|
||||
#self.BoxList.scrollToItem(self.currentBox(), QAbstractItemView.PositionAtCenter)
|
||||
return
|
||||
if self.canvas.editing():
|
||||
selected_shapes = []
|
||||
for item in self.BoxList.selectedItems():
|
||||
selected_shapes.append(self.itemsToShapesbox[item])
|
||||
if selected_shapes:
|
||||
self.canvas.selectShapes(selected_shapes)
|
||||
else:
|
||||
self.canvas.deSelectShape()
|
||||
|
||||
|
||||
def labelItemChanged(self, item):
|
||||
shape = self.itemsToShapes[item]
|
||||
|
@ -1113,6 +1144,8 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
if self.beginner(): # Switch to edit mode.
|
||||
self.canvas.setEditing(True)
|
||||
self.actions.create.setEnabled(True)
|
||||
self.actions.undoLastPoint.setEnabled(False)
|
||||
self.actions.undo.setEnabled(True)
|
||||
else:
|
||||
self.actions.editMode.setEnabled(True)
|
||||
self.setDirty()
|
||||
|
@ -1548,6 +1581,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
self.fileListWidget.insertItem(int(currIndex), item)
|
||||
self.openNextImg()
|
||||
self.actions.saveRec.setEnabled(True)
|
||||
self.actions.saveLabel.setEnabled(True)
|
||||
|
||||
elif mode == 'Auto':
|
||||
if annotationFilePath and self.saveLabels(annotationFilePath, mode=mode):
|
||||
|
@ -1643,7 +1677,8 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
self.setDirty()
|
||||
|
||||
def deleteSelectedShape(self):
|
||||
self.remLabel(self.canvas.deleteSelected())
|
||||
self.remLabels(self.canvas.deleteSelected())
|
||||
self.actions.undo.setEnabled(True)
|
||||
self.setDirty()
|
||||
if self.noShapes():
|
||||
for action in self.actions.onShapesPresent:
|
||||
|
@ -1653,7 +1688,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
color = self.colorDialog.getColor(self.lineColor, u'Choose line color',
|
||||
default=DEFAULT_LINE_COLOR)
|
||||
if color:
|
||||
self.canvas.selectedShape.line_color = color
|
||||
for shape in self.canvas.selectedShapes: shape.line_color = color
|
||||
self.canvas.update()
|
||||
self.setDirty()
|
||||
|
||||
|
@ -1661,7 +1696,7 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
color = self.colorDialog.getColor(self.fillColor, u'Choose fill color',
|
||||
default=DEFAULT_FILL_COLOR)
|
||||
if color:
|
||||
self.canvas.selectedShape.fill_color = color
|
||||
for shape in self.canvas.selectedShapes: shape.fill_color = color
|
||||
self.canvas.update()
|
||||
self.setDirty()
|
||||
|
||||
|
@ -1785,25 +1820,25 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
|
||||
def singleRerecognition(self):
|
||||
img = cv2.imread(self.filePath)
|
||||
shape = self.canvas.selectedShape
|
||||
box = [[int(p.x()), int(p.y())] for p in shape.points]
|
||||
assert len(box) == 4
|
||||
img_crop = get_rotate_crop_image(img, np.array(box, np.float32))
|
||||
if img_crop is None:
|
||||
msg = 'Can not recognise the detection box in ' + self.filePath + '. Please change manually'
|
||||
QMessageBox.information(self, "Information", msg)
|
||||
return
|
||||
result = self.ocr.ocr(img_crop, cls=True, det=False)
|
||||
if result[0][0] != '':
|
||||
result.insert(0, box)
|
||||
print('result in reRec is ', result)
|
||||
if result[1][0] == shape.label:
|
||||
print('label no change')
|
||||
else:
|
||||
shape.label = result[1][0]
|
||||
self.singleLabel(shape)
|
||||
self.setDirty()
|
||||
print(box)
|
||||
for shape in self.canvas.selectedShapes:
|
||||
box = [[int(p.x()), int(p.y())] for p in shape.points]
|
||||
assert len(box) == 4
|
||||
img_crop = get_rotate_crop_image(img, np.array(box, np.float32))
|
||||
if img_crop is None:
|
||||
msg = 'Can not recognise the detection box in ' + self.filePath + '. Please change manually'
|
||||
QMessageBox.information(self, "Information", msg)
|
||||
return
|
||||
result = self.ocr.ocr(img_crop, cls=True, det=False)
|
||||
if result[0][0] != '':
|
||||
result.insert(0, box)
|
||||
print('result in reRec is ', result)
|
||||
if result[1][0] == shape.label:
|
||||
print('label no change')
|
||||
else:
|
||||
shape.label = result[1][0]
|
||||
self.singleLabel(shape)
|
||||
self.setDirty()
|
||||
print(box)
|
||||
|
||||
def autolcm(self):
|
||||
vbox = QVBoxLayout()
|
||||
|
@ -1914,8 +1949,8 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
self.savePPlabel()
|
||||
|
||||
def saveRecResult(self):
|
||||
if None in [self.PPlabelpath, self.PPlabel, self.fileStatedict]:
|
||||
QMessageBox.information(self, "Information", "Save file first")
|
||||
if {} in [self.PPlabelpath, self.PPlabel, self.fileStatedict]:
|
||||
QMessageBox.information(self, "Information", "Check the image first")
|
||||
return
|
||||
|
||||
rec_gt_dir = os.path.dirname(self.PPlabelpath) + '/rec_gt.txt'
|
||||
|
@ -1953,6 +1988,33 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
self.canvas.newShape.disconnect()
|
||||
self.canvas.newShape.connect(partial(self.newShape, False))
|
||||
|
||||
def autoSaveFunc(self):
|
||||
if self.autoSaveOption.isChecked():
|
||||
self.autoSaveNum = 1 # Real auto_Save
|
||||
try:
|
||||
self.saveLabelFile()
|
||||
except:
|
||||
pass
|
||||
print('The program will automatically save once after confirming an image')
|
||||
else:
|
||||
self.autoSaveNum = 5 # Used for backup
|
||||
print('The program will automatically save once after confirming 5 images (default)')
|
||||
|
||||
def undoShapeEdit(self):
|
||||
self.canvas.restoreShape()
|
||||
self.labelList.clear()
|
||||
self.BoxList.clear()
|
||||
self.loadShapes(self.canvas.shapes)
|
||||
self.actions.undo.setEnabled(self.canvas.isShapeRestorable)
|
||||
|
||||
def loadShapes(self, shapes, replace=True):
|
||||
self._noSelectionSlot = True
|
||||
for shape in shapes:
|
||||
self.addLabel(shape)
|
||||
self.labelList.clearSelection()
|
||||
self._noSelectionSlot = False
|
||||
self.canvas.loadShapes(shapes, replace=replace)
|
||||
|
||||
|
||||
def inverted(color):
|
||||
return QColor(*[255 - v for v in color.getRgb()])
|
||||
|
|
|
@ -8,15 +8,18 @@ PPOCRLabel is a semi-automatic graphic annotation tool suitable for OCR field, w
|
|||
|
||||
### Recent Update
|
||||
|
||||
- 2021.2.5: New batch processing and undo functions (by [Evezerest](https://github.com/Evezerest)):
|
||||
- Batch processing function: Press and hold the Ctrl key to select the box, you can move, copy, and delete in batches.
|
||||
- Undo function: In the process of drawing a four-point label box or after editing the box, press Ctrl+Z to undo the previous operation.
|
||||
- Fix image rotation and size problems, optimize the process of editing the mark frame (by [ninetailskim](https://github.com/ninetailskim)、 [edencfc](https://github.com/edencfc)).
|
||||
- 2021.1.11: Optimize the labeling experience (by [edencfc](https://github.com/edencfc)),
|
||||
- Users can choose whether to pop up the label input dialog after drawing the detection box in "View - Pop-up Label Input Dialog".
|
||||
- Users can choose whether to pop up the label input dialog after drawing the detection box in "View - Pop-up Label Input Dialog".
|
||||
- The recognition result scrolls synchronously when users click related detection box.
|
||||
- Click to modify the recognition result.(If you can't change the result, please switch to the system default input method, or switch back to the original input method again)
|
||||
- 2020.12.18: Support re-recognition of a single label box (by [ninetailskim](https://github.com/ninetailskim) ), perfect shortcut keys.
|
||||
|
||||
### TODO:
|
||||
- Lock box mode: For the same scene data, the size and position of the locked detection box can be transferred between different pictures.
|
||||
- Experience optimization: Add undo, batch operation include move, copy, delete and so on, optimize the annotation process.
|
||||
|
||||
## Installation
|
||||
|
||||
|
@ -49,7 +52,7 @@ python3 PPOCRLabel.py
|
|||
```
|
||||
pip3 install pyqt5
|
||||
pip3 uninstall opencv-python # Uninstall opencv manually as it conflicts with pyqt
|
||||
pip3 install opencv-contrib-python-headless # Install the headless version of opencv
|
||||
pip3 install opencv-contrib-python-headless==4.2.0.32 # Install the headless version of opencv
|
||||
cd ./PPOCRLabel # Change the directory to the PPOCRLabel folder
|
||||
python3 PPOCRLabel.py
|
||||
```
|
||||
|
@ -76,12 +79,11 @@ python3 PPOCRLabel.py
|
|||
|
||||
7. Double click the result in 'recognition result' list to manually change inaccurate recognition results.
|
||||
|
||||
8. Click "Check", the image status will switch to "√",then the program automatically jump to the next(The results will not be written directly to the file at this time).
|
||||
8. Click "Check", the image status will switch to "√",then the program automatically jump to the next.
|
||||
|
||||
9. Click "Delete Image" and the image will be deleted to the recycle bin.
|
||||
|
||||
10. Labeling result: the user can save manually through the menu "File - Save Label", while the program will also save automatically after every 5 images confirmed by the user.the manually checked label will be stored in *Label.txt* under the opened picture folder.
|
||||
Click "PaddleOCR"-"Save Recognition Results" in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*<sup>[4]</sup>.
|
||||
10. Labeling result: the user can save manually through the menu "File - Save Label", while the program will also save automatically if "File - Auto Save Label Mode" is selected. The manually checked label will be stored in *Label.txt* under the opened picture folder. Click "PaddleOCR"-"Save Recognition Results" in the menu bar, the recognition training data of such pictures will be saved in the *crop_img* folder, and the recognition label will be saved in *rec_gt.txt*<sup>[4]</sup>.
|
||||
|
||||
### Note
|
||||
|
||||
|
@ -89,8 +91,7 @@ python3 PPOCRLabel.py
|
|||
|
||||
[2] The image status indicates whether the user has saved the image manually. If it has not been saved manually it is "X", otherwise it is "√", PPOCRLabel will not relabel pictures with a status of "√".
|
||||
|
||||
[3] After clicking "Re-recognize", the model will overwrite ALL recognition results in the picture.
|
||||
Therefore, if the recognition result has been manually changed before, it may change after re-recognition.
|
||||
[3] After clicking "Re-recognize", the model will overwrite ALL recognition results in the picture. Therefore, if the recognition result has been manually changed before, it may change after re-recognition.
|
||||
|
||||
[4] The files produced by PPOCRLabel can be found under the opened picture folder including the following, please do not manually change the contents, otherwise it will cause the program to be abnormal.
|
||||
|
||||
|
@ -106,28 +107,29 @@ Therefore, if the recognition result has been manually changed before, it may ch
|
|||
|
||||
### Shortcut keys
|
||||
|
||||
| Shortcut keys | Description |
|
||||
| ---------------- | ------------------------------------------------ |
|
||||
| Ctrl + shift + A | Automatically label all unchecked images |
|
||||
| Ctrl + shift + R | Re-recognize all the labels of the current image |
|
||||
| W | Create a rect box |
|
||||
| Q | Create a four-points box |
|
||||
| Ctrl + E | Edit label of the selected box |
|
||||
| Ctrl + R | Re-recognize the selected box |
|
||||
| Backspace | Delete the selected box |
|
||||
| Ctrl + V | Check image |
|
||||
| Ctrl + Shift + d | Delete image |
|
||||
| D | Next image |
|
||||
| A | Previous image |
|
||||
| Ctrl++ | Zoom in |
|
||||
| Ctrl-- | Zoom out |
|
||||
| ↑→↓← | Move selected box |
|
||||
| Shortcut keys | Description |
|
||||
| ------------------------ | ------------------------------------------------ |
|
||||
| Ctrl + Shift + R | Re-recognize all the labels of the current image |
|
||||
| W | Create a rect box |
|
||||
| Q | Create a four-points box |
|
||||
| Ctrl + E | Edit label of the selected box |
|
||||
| Ctrl + R | Re-recognize the selected box |
|
||||
| Ctrl + C | Copy and paste the selected box |
|
||||
| Ctrl + Left Mouse Button | Multi select the label box |
|
||||
| Backspace | Delete the selected box |
|
||||
| Ctrl + V | Check image |
|
||||
| Ctrl + Shift + d | Delete image |
|
||||
| D | Next image |
|
||||
| A | Previous image |
|
||||
| Ctrl++ | Zoom in |
|
||||
| Ctrl-- | Zoom out |
|
||||
| ↑→↓← | Move selected box |
|
||||
|
||||
### Built-in Model
|
||||
|
||||
- Default model: PPOCRLabel uses the Chinese and English ultra-lightweight OCR model in PaddleOCR by default, supports Chinese, English and number recognition, and multiple language detection.
|
||||
|
||||
- Model language switching: Changing the built-in model language is supportable by clicking "PaddleOCR"-"Choose OCR Model" in the menu bar. Currently supported languagesinclude French, German, Korean, and Japanese.
|
||||
- Model language switching: Changing the built-in model language is supportable by clicking "PaddleOCR"-"Choose OCR Model" in the menu bar. Currently supported languagesinclude French, German, Korean, and Japanese.
|
||||
For specific model download links, please refer to [PaddleOCR Model List](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/models_list_en.md#multilingual-recognition-modelupdating)
|
||||
|
||||
- Custom model: The model trained by users can be replaced by modifying PPOCRLabel.py in [PaddleOCR class instantiation](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/PPOCRLabel/PPOCRLabel.py#L110) referring [Custom Model Code](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/whl_en.md#use-custom-model)
|
||||
|
@ -136,7 +138,7 @@ Therefore, if the recognition result has been manually changed before, it may ch
|
|||
|
||||
PPOCRLabel supports three ways to save Label.txt
|
||||
|
||||
- Automatically save: When it detects that the user has manually checked 5 pictures, the program automatically writes the annotations into Label.txt. The user can change the value of ``self.autoSaveNum`` in ``PPOCRLabel.py`` to set the number of images to be automatically saved after confirmation.
|
||||
- Automatically save: After selecting "File - Auto Save Label Mode", the program will automatically write the annotations into Label.txt every time the user confirms an image. If this option is not turned on, it will be automatically saved after detecting that the user has manually checked 5 images.
|
||||
- Manual save: Click "File-Save Marking Results" to manually save the label.
|
||||
- Close application save
|
||||
|
||||
|
@ -160,11 +162,11 @@ For some data that are difficult to recognize, the recognition results will not
|
|||
```
|
||||
pyrcc5 -o libs/resources.py resources.qrc
|
||||
```
|
||||
- If you get an error ``` module 'cv2' has no attribute 'INTER_NEAREST'```, you need to delete all opencv related packages first, and then reinstall the headless version of opencv
|
||||
- If you get an error ``` module 'cv2' has no attribute 'INTER_NEAREST'```, you need to delete all opencv related packages first, and then reinstall the 4.2.0.32 version of headless opencv
|
||||
```
|
||||
pip install opencv-contrib-python-headless
|
||||
pip install opencv-contrib-python-headless==4.2.0.32
|
||||
```
|
||||
|
||||
|
||||
### Related
|
||||
|
||||
1.[Tzutalin. LabelImg. Git code (2015)](https://github.com/tzutalin/labelImg)
|
||||
1.[Tzutalin. LabelImg. Git code (2015)](https://github.com/tzutalin/labelImg)
|
|
@ -8,6 +8,10 @@ PPOCRLabel是一款适用于OCR领域的半自动化图形标注工具,内置P
|
|||
|
||||
#### 近期更新
|
||||
|
||||
- 2021.2.5:新增批处理与撤销功能(by [Evezerest](https://github.com/Evezerest))
|
||||
- 批处理功能:按住Ctrl键选择标记框后可批量移动、复制、删除。
|
||||
- 撤销功能:在绘制四点标注框过程中或对框进行编辑操作后,按下Ctrl+Z可撤销上一部操作。
|
||||
- 修复图像旋转和尺寸问题、优化编辑标记框过程(by [ninetailskim](https://github.com/ninetailskim)、 [edencfc](https://github.com/edencfc))
|
||||
- 2021.1.11:优化标注体验(by [edencfc](https://github.com/edencfc)):
|
||||
- 用户可在“视图 - 弹出标记输入框”选择在画完检测框后标记输入框是否弹出。
|
||||
- 识别结果与检测框同步滚动。
|
||||
|
@ -17,9 +21,8 @@ PPOCRLabel是一款适用于OCR领域的半自动化图形标注工具,内置P
|
|||
#### 尽请期待
|
||||
|
||||
- 锁定框模式:针对同一场景数据,被锁定的检测框的大小与位置能在不同图片之间传递。
|
||||
- 体验优化:增加撤销操作,批量移动、复制、删除等功能。优化标注流程。
|
||||
|
||||
如果您对以上内容感兴趣或对完善工具有不一样的想法,欢迎加入我们的队伍与我们共同开发
|
||||
如果您对以上内容感兴趣或对完善工具有不一样的想法,欢迎加入我们的SIG队伍与我们共同开发。可以在[此处](https://github.com/PaddlePaddle/PaddleOCR/issues/1728)完成问卷和前置任务,经过我们确认相关内容后即可正式加入,享受SIG福利,共同为OCR开源事业贡献(特别说明:针对PPOCRLabel的改进也属于PaddleOCR前置任务)
|
||||
|
||||
|
||||
## 安装
|
||||
|
@ -49,7 +52,7 @@ python3 PPOCRLabel.py --lang ch
|
|||
```
|
||||
pip3 install pyqt5
|
||||
pip3 uninstall opencv-python # 由于mac版本的opencv与pyqt有冲突,需先手动卸载opencv
|
||||
pip3 install opencv-contrib-python-headless # 安装headless版本的open-cv
|
||||
pip3 install opencv-contrib-python-headless==4.2.0.32 # 安装headless版本的open-cv
|
||||
cd ./PPOCRLabel # 将目录切换到PPOCRLabel文件夹下
|
||||
python3 PPOCRLabel.py --lang ch
|
||||
```
|
||||
|
@ -65,9 +68,9 @@ python3 PPOCRLabel.py --lang ch
|
|||
5. 标记框绘制完成后,用户点击 “确认”,检测框会先被预分配一个 “待识别” 标签。
|
||||
6. 重新识别:将图片中的所有检测画绘制/调整完成后,点击 “重新识别”,PPOCR模型会对当前图片中的**所有检测框**重新识别<sup>[3]</sup>。
|
||||
7. 内容更改:双击识别结果,对不准确的识别结果进行手动更改。
|
||||
8. 确认标记:点击 “确认”,图片状态切换为 “√”,跳转至下一张(此时不会直接将结果写入文件)。
|
||||
8. **确认标记**:点击 “确认”,图片状态切换为 “√”,跳转至下一张。
|
||||
9. 删除:点击 “删除图像”,图片将会被删除至回收站。
|
||||
10. 保存结果:用户可以通过菜单中“文件-保存标记结果”手动保存,同时程序也会在用户每确认5张图片后自动保存一次。手动确认过的标记将会被存放在所打开图片文件夹下的*Label.txt*中。在菜单栏点击 “文件” - "保存识别结果"后,会将此类图片的识别训练数据保存在*crop_img*文件夹下,识别标签保存在*rec_gt.txt*中<sup>[4]</sup>。
|
||||
10. 保存结果:用户可以通过菜单中“文件-保存标记结果”手动保存,同时也可以点击“文件 - 自动保存标记结果”开启自动保存。手动确认过的标记将会被存放在所打开图片文件夹下的*Label.txt*中。在菜单栏点击 “文件” - "保存识别结果"后,会将此类图片的识别训练数据保存在*crop_img*文件夹下,识别标签保存在*rec_gt.txt*中<sup>[4]</sup>。
|
||||
|
||||
### 注意
|
||||
|
||||
|
@ -93,12 +96,13 @@ python3 PPOCRLabel.py --lang ch
|
|||
|
||||
| 快捷键 | 说明 |
|
||||
| ---------------- | ---------------------------- |
|
||||
| Ctrl + shift + A | 自动标注所有未确认过的图片 |
|
||||
| Ctrl + shift + R | 对当前图片的所有标记重新识别 |
|
||||
| W | 新建矩形框 |
|
||||
| Q | 新建四点框 |
|
||||
| Ctrl + E | 编辑所选框标签 |
|
||||
| Ctrl + R | 重新识别所选标记 |
|
||||
| Ctrl + C | 复制并粘贴选中的标记框 |
|
||||
| Ctrl + 鼠标左键 | 多选标记框 |
|
||||
| Backspace | 删除所选框 |
|
||||
| Ctrl + V | 确认本张图片标记 |
|
||||
| Ctrl + Shift + d | 删除本张图片 |
|
||||
|
@ -120,7 +124,7 @@ python3 PPOCRLabel.py --lang ch
|
|||
|
||||
PPOCRLabel支持三种保存方式:
|
||||
|
||||
- 程序自动保存:当检测到用户手动确认过5张图片后,程序自动将标记结果写入Label.txt中。其中用户可通过更改```PPOCRLabel.py```中的```self.autoSaveNum```的数值设置确认几张图片后进行自动保存。
|
||||
- 自动保存:点击“文件 - 自动保存标记结果”后,用户每确认过一张图片,程序自动将标记结果写入Label.txt中。若未开启此选项,则检测到用户手动确认过5张图片后进行自动保存。
|
||||
- 手动保存:点击“文件 - 保存标记结果”手动保存标记。
|
||||
- 关闭应用程序保存
|
||||
|
||||
|
@ -132,22 +136,22 @@ PPOCRLabel支持三种保存方式:
|
|||
|
||||
### 错误提示
|
||||
- 如果同时使用whl包安装了paddleocr,其优先级大于通过paddleocr.py调用PaddleOCR类,whl包未更新时会导致程序异常。
|
||||
|
||||
|
||||
- PPOCRLabel**不支持对中文文件名**的图片进行自动标注。
|
||||
|
||||
- 针对Linux用户:如果您在打开软件过程中出现**objc[XXXXX]**开头的错误,证明您的opencv版本太高,建议安装4.2版本:
|
||||
```
|
||||
pip install opencv-python==4.2.0.32
|
||||
```
|
||||
|
||||
|
||||
- 如果出现 ```Missing string id``` 开头的错误,需要重新编译资源:
|
||||
```
|
||||
pyrcc5 -o libs/resources.py resources.qrc
|
||||
```
|
||||
|
||||
- 如果出现``` module 'cv2' has no attribute 'INTER_NEAREST'```错误,需要首先删除所有opencv相关包,然后重新安装headless版本的opencv
|
||||
|
||||
- 如果出现``` module 'cv2' has no attribute 'INTER_NEAREST'```错误,需要首先删除所有opencv相关包,然后重新安装4.2.0.32版本的headless opencv
|
||||
```
|
||||
pip install opencv-contrib-python-headless
|
||||
pip install opencv-contrib-python-headless==4.2.0.32
|
||||
```
|
||||
|
||||
### 参考资料
|
||||
|
|
|
@ -37,7 +37,8 @@ class Canvas(QWidget):
|
|||
zoomRequest = pyqtSignal(int)
|
||||
scrollRequest = pyqtSignal(int, int)
|
||||
newShape = pyqtSignal()
|
||||
selectionChanged = pyqtSignal(bool)
|
||||
# selectionChanged = pyqtSignal(bool)
|
||||
selectionChanged = pyqtSignal(list)
|
||||
shapeMoved = pyqtSignal()
|
||||
drawingPolygon = pyqtSignal(bool)
|
||||
|
||||
|
@ -51,9 +52,11 @@ class Canvas(QWidget):
|
|||
# Initialise local state.
|
||||
self.mode = self.EDIT
|
||||
self.shapes = []
|
||||
self.shapesBackups = []
|
||||
self.current = None
|
||||
self.selectedShapes = []
|
||||
self.selectedShape = None # save the selected shape here
|
||||
self.selectedShapeCopy = None
|
||||
self.selectedShapesCopy = []
|
||||
self.drawingLineColor = QColor(0, 0, 255)
|
||||
self.drawingRectColor = QColor(0, 0, 255)
|
||||
self.line = Shape(line_color=self.drawingLineColor)
|
||||
|
@ -77,6 +80,7 @@ class Canvas(QWidget):
|
|||
self.drawSquare = False
|
||||
self.fourpoint = True # ADD
|
||||
self.pointnum = 0
|
||||
self.movingShape = False
|
||||
|
||||
#initialisation for panning
|
||||
self.pan_initial_pos = QPoint()
|
||||
|
@ -149,37 +153,20 @@ class Canvas(QWidget):
|
|||
clipped_x = min(max(0, pos.x()), size.width())
|
||||
clipped_y = min(max(0, pos.y()), size.height())
|
||||
pos = QPointF(clipped_x, clipped_y)
|
||||
elif len(self.current) > 1 and self.closeEnough(pos, self.current[0]) and not self.fourpoint:
|
||||
|
||||
elif len(self.current) > 1 and self.closeEnough(pos, self.current[0]):
|
||||
# Attract line to starting point and colorise to alert the
|
||||
# user:
|
||||
pos = self.current[0]
|
||||
color = self.current.line_color
|
||||
self.overrideCursor(CURSOR_POINT)
|
||||
self.current.highlightVertex(0, Shape.NEAR_VERTEX)
|
||||
elif ( # ADD
|
||||
len(self.current) > 1
|
||||
and self.fourpoint
|
||||
and self.closeEnough(pos, self.current[0])
|
||||
):
|
||||
# Attract line to starting point and
|
||||
# colorise to alert the user.
|
||||
pos = self.current[0]
|
||||
self.overrideCursor(CURSOR_POINT)
|
||||
self.current.highlightVertex(0, Shape.NEAR_VERTEX)
|
||||
|
||||
|
||||
if self.drawSquare:
|
||||
initPos = self.current[0]
|
||||
minX = initPos.x()
|
||||
minY = initPos.y()
|
||||
min_size = min(abs(pos.x() - minX), abs(pos.y() - minY))
|
||||
directionX = -1 if pos.x() - minX < 0 else 1
|
||||
directionY = -1 if pos.y() - minY < 0 else 1
|
||||
self.line[1] = QPointF(minX + directionX * min_size, minY + directionY * min_size)
|
||||
self.line.points = [self.current[0], pos]
|
||||
self.line.close()
|
||||
|
||||
elif self.fourpoint:
|
||||
# self.line[self.pointnum] = pos # OLD
|
||||
|
||||
self.line[0] = self.current[-1]
|
||||
self.line[1] = pos
|
||||
|
||||
|
@ -196,12 +183,14 @@ class Canvas(QWidget):
|
|||
|
||||
# Polygon copy moving.
|
||||
if Qt.RightButton & ev.buttons():
|
||||
if self.selectedShapeCopy and self.prevPoint:
|
||||
if self.selectedShapesCopy and self.prevPoint:
|
||||
self.overrideCursor(CURSOR_MOVE)
|
||||
self.boundedMoveShape(self.selectedShapeCopy, pos)
|
||||
self.boundedMoveShape(self.selectedShapesCopy, pos)
|
||||
self.repaint()
|
||||
elif self.selectedShape:
|
||||
self.selectedShapeCopy = self.selectedShape.copy()
|
||||
elif self.selectedShapes:
|
||||
self.selectedShapesCopy = [
|
||||
s.copy() for s in self.selectedShapes
|
||||
]
|
||||
self.repaint()
|
||||
return
|
||||
|
||||
|
@ -211,11 +200,13 @@ class Canvas(QWidget):
|
|||
self.boundedMoveVertex(pos)
|
||||
self.shapeMoved.emit()
|
||||
self.repaint()
|
||||
elif self.selectedShape and self.prevPoint:
|
||||
self.movingShape = True
|
||||
elif self.selectedShapes and self.prevPoint:
|
||||
self.overrideCursor(CURSOR_MOVE)
|
||||
self.boundedMoveShape(self.selectedShape, pos)
|
||||
self.boundedMoveShape(self.selectedShapes, pos)
|
||||
self.shapeMoved.emit()
|
||||
self.repaint()
|
||||
self.movingShape = True
|
||||
else:
|
||||
#pan
|
||||
delta_x = pos.x() - self.pan_initial_pos.x()
|
||||
|
@ -263,65 +254,60 @@ class Canvas(QWidget):
|
|||
|
||||
def mousePressEvent(self, ev):
|
||||
pos = self.transformPos(ev.pos())
|
||||
|
||||
if ev.button() == Qt.LeftButton:
|
||||
if self.drawing():
|
||||
# self.handleDrawing(pos) # OLD
|
||||
|
||||
|
||||
if self.current and self.fourpoint: # ADD IF
|
||||
# Add point to existing shape.
|
||||
print('Adding points in mousePressEvent is ', self.line[1])
|
||||
self.current.addPoint(self.line[1])
|
||||
self.line[0] = self.current[-1]
|
||||
if self.current.isClosed():
|
||||
# print('1111')
|
||||
if self.current:
|
||||
if self.fourpoint: # ADD IF
|
||||
# Add point to existing shape.
|
||||
# print('Adding points in mousePressEvent is ', self.line[1])
|
||||
self.current.addPoint(self.line[1])
|
||||
self.line[0] = self.current[-1]
|
||||
if self.current.isClosed():
|
||||
# print('1111')
|
||||
self.finalise()
|
||||
elif self.drawSquare: # 增加
|
||||
assert len(self.current.points) == 1
|
||||
self.current.points = self.line.points
|
||||
self.finalise()
|
||||
elif not self.outOfPixmap(pos):
|
||||
# Create new shape.
|
||||
self.current = Shape()# self.current = Shape(shape_type=self.createMode)
|
||||
self.current = Shape()
|
||||
self.current.addPoint(pos)
|
||||
# if self.createMode == "point":
|
||||
# self.finalise()
|
||||
# else:
|
||||
# if self.createMode == "circle":
|
||||
# self.current.shape_type = "circle"
|
||||
self.line.points = [pos, pos]
|
||||
self.setHiding()
|
||||
self.drawingPolygon.emit(True)
|
||||
self.update()
|
||||
|
||||
|
||||
else:
|
||||
selection = self.selectShapePoint(pos)
|
||||
group_mode = int(ev.modifiers()) == Qt.ControlModifier
|
||||
self.selectShapePoint(pos, multiple_selection_mode=group_mode)
|
||||
self.prevPoint = pos
|
||||
|
||||
if selection is None:
|
||||
#pan
|
||||
QApplication.setOverrideCursor(QCursor(Qt.OpenHandCursor))
|
||||
self.pan_initial_pos = pos
|
||||
self.pan_initial_pos = pos
|
||||
|
||||
elif ev.button() == Qt.RightButton and self.editing():
|
||||
self.selectShapePoint(pos)
|
||||
group_mode = int(ev.modifiers()) == Qt.ControlModifier
|
||||
self.selectShapePoint(pos, multiple_selection_mode=group_mode)
|
||||
self.prevPoint = pos
|
||||
self.update()
|
||||
|
||||
def mouseReleaseEvent(self, ev):
|
||||
if ev.button() == Qt.RightButton:
|
||||
menu = self.menus[bool(self.selectedShapeCopy)]
|
||||
menu = self.menus[bool(self.selectedShapesCopy)]
|
||||
self.restoreCursor()
|
||||
if not menu.exec_(self.mapToGlobal(ev.pos()))\
|
||||
and self.selectedShapeCopy:
|
||||
and self.selectedShapesCopy:
|
||||
# Cancel the move by deleting the shadow copy.
|
||||
self.selectedShapeCopy = None
|
||||
# self.selectedShapeCopy = None
|
||||
self.selectedShapesCopy = []
|
||||
self.repaint()
|
||||
elif ev.button() == Qt.LeftButton and self.selectedShape: # OLD
|
||||
|
||||
elif ev.button() == Qt.LeftButton and self.selectedShapes:
|
||||
if self.selectedVertex():
|
||||
self.overrideCursor(CURSOR_POINT)
|
||||
else:
|
||||
self.overrideCursor(CURSOR_GRAB)
|
||||
|
||||
|
||||
elif ev.button() == Qt.LeftButton and not self.fourpoint:
|
||||
pos = self.transformPos(ev.pos())
|
||||
if self.drawing():
|
||||
|
@ -330,24 +316,37 @@ class Canvas(QWidget):
|
|||
#pan
|
||||
QApplication.restoreOverrideCursor() # ?
|
||||
|
||||
if self.movingShape and self.hShape:
|
||||
index = self.shapes.index(self.hShape)
|
||||
if (
|
||||
self.shapesBackups[-1][index].points
|
||||
!= self.shapes[index].points
|
||||
):
|
||||
self.storeShapes()
|
||||
self.shapeMoved.emit() # connect to updateBoxlist in PPOCRLabel.py
|
||||
|
||||
self.movingShape = False
|
||||
|
||||
|
||||
def endMove(self, copy=False):
|
||||
assert self.selectedShape and self.selectedShapeCopy
|
||||
shape = self.selectedShapeCopy
|
||||
#del shape.fill_color
|
||||
#del shape.line_color
|
||||
assert self.selectedShapes and self.selectedShapesCopy
|
||||
assert len(self.selectedShapesCopy) == len(self.selectedShapes)
|
||||
if copy:
|
||||
self.shapes.append(shape)
|
||||
self.selectedShape.selected = False
|
||||
self.selectedShape = shape
|
||||
self.repaint()
|
||||
for i, shape in enumerate(self.selectedShapesCopy):
|
||||
self.shapes.append(shape)
|
||||
self.selectedShapes[i].selected = False
|
||||
self.selectedShapes[i] = shape
|
||||
else:
|
||||
self.selectedShape.points = [p for p in shape.points]
|
||||
self.selectedShapeCopy = None
|
||||
for i, shape in enumerate(self.selectedShapesCopy):
|
||||
self.selectedShapes[i].points = shape.points
|
||||
self.selectedShapesCopy = []
|
||||
self.repaint()
|
||||
self.storeShapes()
|
||||
return True
|
||||
|
||||
def hideBackroundShapes(self, value):
|
||||
self.hideBackround = value
|
||||
if self.selectedShape:
|
||||
if self.selectedShapes:
|
||||
# Only hide other shapes if there is a current selection.
|
||||
# Otherwise the user will not be able to select a shape.
|
||||
self.setHiding(True)
|
||||
|
@ -363,7 +362,7 @@ class Canvas(QWidget):
|
|||
if self.pointnum == 3:
|
||||
self.finalise()
|
||||
|
||||
else: # 按住送掉后跳到这里
|
||||
else:
|
||||
initPos = self.current[0]
|
||||
print('initPos', self.current[0])
|
||||
minX = initPos.x()
|
||||
|
@ -399,28 +398,33 @@ class Canvas(QWidget):
|
|||
self.current.popPoint()
|
||||
self.finalise()
|
||||
|
||||
def selectShape(self, shape):
|
||||
self.deSelectShape()
|
||||
shape.selected = True
|
||||
self.selectedShape = shape
|
||||
def selectShapes(self, shapes):
|
||||
for s in shapes: s.seleted = True
|
||||
self.setHiding()
|
||||
self.selectionChanged.emit(True)
|
||||
self.selectionChanged.emit(shapes)
|
||||
self.update()
|
||||
|
||||
def selectShapePoint(self, point):
|
||||
|
||||
def selectShapePoint(self, point, multiple_selection_mode):
|
||||
"""Select the first shape created which contains this point."""
|
||||
self.deSelectShape()
|
||||
if self.selectedVertex(): # A vertex is marked for selection.
|
||||
index, shape = self.hVertex, self.hShape
|
||||
shape.highlightVertex(index, shape.MOVE_VERTEX)
|
||||
self.selectShape(shape)
|
||||
return self.hVertex
|
||||
for shape in reversed(self.shapes):
|
||||
if self.isVisible(shape) and shape.containsPoint(point):
|
||||
self.selectShape(shape)
|
||||
self.calculateOffsets(shape, point)
|
||||
return self.selectedShape
|
||||
return None
|
||||
else:
|
||||
for shape in reversed(self.shapes):
|
||||
if self.isVisible(shape) and shape.containsPoint(point):
|
||||
self.calculateOffsets(shape, point)
|
||||
self.setHiding()
|
||||
if multiple_selection_mode:
|
||||
if shape not in self.selectedShapes: # list
|
||||
self.selectionChanged.emit(
|
||||
self.selectedShapes + [shape]
|
||||
)
|
||||
else:
|
||||
self.selectionChanged.emit([shape])
|
||||
return
|
||||
self.deSelectShape()
|
||||
|
||||
def calculateOffsets(self, shape, point):
|
||||
rect = shape.boundingRect()
|
||||
|
@ -465,22 +469,28 @@ class Canvas(QWidget):
|
|||
else:
|
||||
shiftPos = pos - point
|
||||
|
||||
shape.moveVertexBy(index, shiftPos)
|
||||
if [shape[0].x(), shape[0].y(), shape[2].x(), shape[2].y()] \
|
||||
== [shape[3].x(),shape[1].y(),shape[1].x(),shape[3].y()]:
|
||||
shape.moveVertexBy(index, shiftPos)
|
||||
lindex = (index + 1) % 4
|
||||
rindex = (index + 3) % 4
|
||||
lshift = None
|
||||
rshift = None
|
||||
if index % 2 == 0:
|
||||
rshift = QPointF(shiftPos.x(), 0)
|
||||
lshift = QPointF(0, shiftPos.y())
|
||||
else:
|
||||
lshift = QPointF(shiftPos.x(), 0)
|
||||
rshift = QPointF(0, shiftPos.y())
|
||||
shape.moveVertexBy(rindex, rshift)
|
||||
shape.moveVertexBy(lindex, lshift)
|
||||
|
||||
lindex = (index + 1) % 4
|
||||
rindex = (index + 3) % 4
|
||||
lshift = None
|
||||
rshift = None
|
||||
if index % 2 == 0:
|
||||
rshift = QPointF(shiftPos.x(), 0)
|
||||
lshift = QPointF(0, shiftPos.y())
|
||||
else:
|
||||
lshift = QPointF(shiftPos.x(), 0)
|
||||
rshift = QPointF(0, shiftPos.y())
|
||||
shape.moveVertexBy(rindex, rshift)
|
||||
shape.moveVertexBy(lindex, lshift)
|
||||
shape.moveVertexBy(index, shiftPos)
|
||||
|
||||
def boundedMoveShape(self, shape, pos):
|
||||
|
||||
def boundedMoveShape(self, shapes, pos):
|
||||
if type(shapes).__name__ != 'list': shapes = [shapes]
|
||||
if self.outOfPixmap(pos):
|
||||
return False # No need to move
|
||||
o1 = pos + self.offsets[0]
|
||||
|
@ -497,46 +507,55 @@ class Canvas(QWidget):
|
|||
#self.calculateOffsets(self.selectedShape, pos)
|
||||
dp = pos - self.prevPoint
|
||||
if dp:
|
||||
shape.moveBy(dp)
|
||||
for shape in shapes:
|
||||
shape.moveBy(dp)
|
||||
self.prevPoint = pos
|
||||
return True
|
||||
return False
|
||||
|
||||
def deSelectShape(self):
|
||||
if self.selectedShape:
|
||||
self.selectedShape.selected = False
|
||||
self.selectedShape = None
|
||||
if self.selectedShapes:
|
||||
for shape in self.selectedShapes: shape.selected=False
|
||||
self.setHiding(False)
|
||||
self.selectionChanged.emit(False)
|
||||
self.selectionChanged.emit([])
|
||||
self.update()
|
||||
|
||||
def deleteSelected(self):
|
||||
if self.selectedShape:
|
||||
shape = self.selectedShape
|
||||
self.shapes.remove(self.selectedShape)
|
||||
self.selectedShape = None
|
||||
deleted_shapes = []
|
||||
if self.selectedShapes:
|
||||
for shape in self.selectedShapes:
|
||||
self.shapes.remove(shape)
|
||||
deleted_shapes.append(shape)
|
||||
self.storeShapes()
|
||||
self.selectedShapes = []
|
||||
self.update()
|
||||
return shape
|
||||
return deleted_shapes
|
||||
|
||||
def storeShapes(self):
|
||||
shapesBackup = []
|
||||
for shape in self.shapes:
|
||||
shapesBackup.append(shape.copy())
|
||||
if len(self.shapesBackups) >= 10:
|
||||
self.shapesBackups = self.shapesBackups[-9:]
|
||||
self.shapesBackups.append(shapesBackup)
|
||||
|
||||
def copySelectedShape(self):
|
||||
if self.selectedShape:
|
||||
shape = self.selectedShape.copy()
|
||||
self.deSelectShape()
|
||||
self.shapes.append(shape)
|
||||
shape.selected = True
|
||||
self.selectedShape = shape
|
||||
self.boundedShiftShape(shape)
|
||||
return shape
|
||||
if self.selectedShapes:
|
||||
self.selectedShapesCopy = [s.copy() for s in self.selectedShapes]
|
||||
self.boundedShiftShapes(self.selectedShapesCopy)
|
||||
self.endMove(copy=True)
|
||||
return self.selectedShapes
|
||||
|
||||
def boundedShiftShape(self, shape):
|
||||
def boundedShiftShapes(self, shapes):
|
||||
# Try to move in one direction, and if it fails in another.
|
||||
# Give up if both fail.
|
||||
point = shape[0]
|
||||
offset = QPointF(2.0, 2.0)
|
||||
self.calculateOffsets(shape, point)
|
||||
self.prevPoint = point
|
||||
if not self.boundedMoveShape(shape, point - offset):
|
||||
self.boundedMoveShape(shape, point + offset)
|
||||
for shape in shapes:
|
||||
point = shape[0]
|
||||
offset = QPointF(2.0, 2.0)
|
||||
self.calculateOffsets(shape, point)
|
||||
self.prevPoint = point
|
||||
if not self.boundedMoveShape(shape, point - offset):
|
||||
self.boundedMoveShape(shape, point + offset)
|
||||
|
||||
def paintEvent(self, event):
|
||||
if not self.pixmap:
|
||||
|
@ -560,8 +579,9 @@ class Canvas(QWidget):
|
|||
if self.current:
|
||||
self.current.paint(p)
|
||||
self.line.paint(p)
|
||||
if self.selectedShapeCopy:
|
||||
self.selectedShapeCopy.paint(p)
|
||||
if self.selectedShapesCopy:
|
||||
for s in self.selectedShapesCopy:
|
||||
s.paint(p)
|
||||
|
||||
# Paint rect
|
||||
if self.current is not None and len(self.line) == 2 and not self.fourpoint:
|
||||
|
@ -690,13 +710,13 @@ class Canvas(QWidget):
|
|||
elif key == Qt.Key_Return and self.canCloseShape():
|
||||
self.finalise()
|
||||
elif key == Qt.Key_Left and self.selectedShape:
|
||||
self.moveOnePixel('Left')
|
||||
self.moveOnePixel('Left')
|
||||
elif key == Qt.Key_Right and self.selectedShape:
|
||||
self.moveOnePixel('Right')
|
||||
self.moveOnePixel('Right')
|
||||
elif key == Qt.Key_Up and self.selectedShape:
|
||||
self.moveOnePixel('Up')
|
||||
self.moveOnePixel('Up')
|
||||
elif key == Qt.Key_Down and self.selectedShape:
|
||||
self.moveOnePixel('Down')
|
||||
self.moveOnePixel('Down')
|
||||
|
||||
def moveOnePixel(self, direction):
|
||||
# print(self.selectedShape.points)
|
||||
|
@ -739,6 +759,7 @@ class Canvas(QWidget):
|
|||
|
||||
if fill_color:
|
||||
self.shapes[-1].fill_color = fill_color
|
||||
self.storeShapes()
|
||||
|
||||
return self.shapes[-1]
|
||||
|
||||
|
@ -749,6 +770,17 @@ class Canvas(QWidget):
|
|||
self.line.points = [self.current[-1], self.current[0]]
|
||||
self.drawingPolygon.emit(True)
|
||||
|
||||
def undoLastPoint(self):
|
||||
if not self.current or self.current.isClosed():
|
||||
return
|
||||
self.current.popPoint()
|
||||
if len(self.current) > 0:
|
||||
self.line[0] = self.current[-1]
|
||||
else:
|
||||
self.current = None
|
||||
self.drawingPolygon.emit(False)
|
||||
self.repaint()
|
||||
|
||||
def resetAllLines(self):
|
||||
assert self.shapes
|
||||
self.current = self.shapes.pop()
|
||||
|
@ -762,11 +794,18 @@ class Canvas(QWidget):
|
|||
def loadPixmap(self, pixmap):
|
||||
self.pixmap = pixmap
|
||||
self.shapes = []
|
||||
self.repaint() # 这函数在哪
|
||||
self.repaint()
|
||||
|
||||
def loadShapes(self, shapes):
|
||||
self.shapes = list(shapes)
|
||||
def loadShapes(self, shapes, replace=True):
|
||||
if replace:
|
||||
self.shapes = list(shapes)
|
||||
else:
|
||||
self.shapes.extend(shapes)
|
||||
self.current = None
|
||||
self.hShape = None
|
||||
self.hVertex = None
|
||||
# self.hEdge = None
|
||||
self.storeShapes()
|
||||
self.repaint()
|
||||
|
||||
def setShapeVisible(self, shape, value):
|
||||
|
@ -793,6 +832,24 @@ class Canvas(QWidget):
|
|||
self.restoreCursor()
|
||||
self.pixmap = None
|
||||
self.update()
|
||||
self.shapesBackups = []
|
||||
|
||||
def setDrawingShapeToSquare(self, status):
|
||||
self.drawSquare = status
|
||||
|
||||
def restoreShape(self):
|
||||
if not self.isShapeRestorable:
|
||||
return
|
||||
self.shapesBackups.pop() # latest
|
||||
shapesBackup = self.shapesBackups.pop()
|
||||
self.shapes = shapesBackup
|
||||
self.selectedShapes = []
|
||||
for shape in self.shapes:
|
||||
shape.selected = False
|
||||
self.repaint()
|
||||
|
||||
@property
|
||||
def isShapeRestorable(self):
|
||||
if len(self.shapesBackups) < 2:
|
||||
return False
|
||||
return True
|
|
@ -82,7 +82,7 @@ class Shape(object):
|
|||
return False
|
||||
|
||||
def addPoint(self, point):
|
||||
if not self.reachMaxPoints():
|
||||
if not self.reachMaxPoints(): # 4个点时发出close信号
|
||||
self.points.append(point)
|
||||
|
||||
def popPoint(self):
|
||||
|
|
|
@ -96,4 +96,7 @@ hideBox=隐藏所有标注
|
|||
showBox=显示所有标注
|
||||
saveLabel=保存标记结果
|
||||
singleRe=重识别此区块
|
||||
labelDialogOption=弹出标记输入框
|
||||
labelDialogOption=弹出标记输入框
|
||||
undo=撤销
|
||||
undoLastPoint=撤销上个点
|
||||
autoSaveMode=自动保存标记结果
|
|
@ -96,4 +96,7 @@ hideBox=Hide All Box
|
|||
showBox=Show All Box
|
||||
saveLabel=Save Label
|
||||
singleRe=Re-recognition RectBox
|
||||
labelDialogOption=Pop-up Label Input Dialog
|
||||
labelDialogOption=Pop-up Label Input Dialog
|
||||
undo=Undo
|
||||
undoLastPoint=Undo Last Point
|
||||
autoSaveMode=Auto Save Label Mode
|
|
@ -42,7 +42,7 @@ The above pictures are the visualizations of the general ppocr_server model. For
|
|||
- Scan the QR code below with your Wechat, you can access to official technical exchange group. Look forward to your participation.
|
||||
|
||||
<div align="center">
|
||||
<img src="./doc/joinus.PNG" width = "200" height = "200" />
|
||||
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.0/doc/joinus.PNG" width = "200" height = "200" />
|
||||
</div>
|
||||
|
||||
|
||||
|
@ -93,7 +93,7 @@ For a new language request, please refer to [Guideline for new language_requests
|
|||
- [Quick Inference Based on PIP](./doc/doc_en/whl_en.md)
|
||||
- [Python Inference](./doc/doc_en/inference_en.md)
|
||||
- [C++ Inference](./deploy/cpp_infer/readme_en.md)
|
||||
- [Serving](./deploy/hubserving/readme_en.md)
|
||||
- [Serving](./deploy/pdserving/README.md)
|
||||
- [Mobile](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/deploy/lite/readme_en.md)
|
||||
- [Benchmark](./doc/doc_en/benchmark_en.md)
|
||||
- Data Annotation and Synthesis
|
||||
|
|
|
@ -46,7 +46,7 @@ PaddleOCR同时支持动态图与静态图两种编程范式
|
|||
- 微信扫描二维码加入官方交流群,获得更高效的问题答疑,与各行各业开发者充分交流,期待您的加入。
|
||||
|
||||
<div align="center">
|
||||
<img src="./doc/joinus.PNG" width = "200" height = "200" />
|
||||
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/release/2.0/doc/joinus.PNG" width = "200" height = "200" />
|
||||
</div>
|
||||
|
||||
## 快速体验
|
||||
|
@ -88,7 +88,7 @@ PaddleOCR同时支持动态图与静态图两种编程范式
|
|||
- [基于pip安装whl包快速推理](./doc/doc_ch/whl.md)
|
||||
- [基于Python脚本预测引擎推理](./doc/doc_ch/inference.md)
|
||||
- [基于C++预测引擎推理](./deploy/cpp_infer/readme.md)
|
||||
- [服务化部署](./deploy/hubserving/readme.md)
|
||||
- [服务化部署](./deploy/pdserving/README_CN.md)
|
||||
- [端侧部署](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/deploy/lite/readme.md)
|
||||
- [Benchmark](./doc/doc_ch/benchmark.md)
|
||||
- 数据集
|
||||
|
|
|
@ -38,7 +38,15 @@ class StyleTextRecPredictor(object):
|
|||
self.std = config["Predictor"]["std"]
|
||||
self.expand_result = config["Predictor"]["expand_result"]
|
||||
|
||||
def predict(self, style_input, text_input):
|
||||
def reshape_to_same_height(self, img_list):
|
||||
h = img_list[0].shape[0]
|
||||
for idx in range(1, len(img_list)):
|
||||
new_w = round(1.0 * img_list[idx].shape[1] /
|
||||
img_list[idx].shape[0] * h)
|
||||
img_list[idx] = cv2.resize(img_list[idx], (new_w, h))
|
||||
return img_list
|
||||
|
||||
def predict_single_image(self, style_input, text_input):
|
||||
style_input = self.rep_style_input(style_input, text_input)
|
||||
tensor_style_input = self.preprocess(style_input)
|
||||
tensor_text_input = self.preprocess(text_input)
|
||||
|
@ -64,6 +72,21 @@ class StyleTextRecPredictor(object):
|
|||
"fake_bg": fake_bg,
|
||||
}
|
||||
|
||||
def predict(self, style_input, text_input_list):
|
||||
if not isinstance(text_input_list, (tuple, list)):
|
||||
return self.predict_single_image(style_input, text_input_list)
|
||||
|
||||
synth_result_list = []
|
||||
for text_input in text_input_list:
|
||||
synth_result = self.predict_single_image(style_input, text_input)
|
||||
synth_result_list.append(synth_result)
|
||||
|
||||
for key in synth_result:
|
||||
res = [r[key] for r in synth_result_list]
|
||||
res = self.reshape_to_same_height(res)
|
||||
synth_result[key] = np.concatenate(res, axis=1)
|
||||
return synth_result
|
||||
|
||||
def preprocess(self, img):
|
||||
img = (img.astype('float32') * self.scale - self.mean) / self.std
|
||||
img_height, img_width, channel = img.shape
|
||||
|
|
|
@ -12,6 +12,8 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
from utils.config import ArgsParser, load_config, override_config
|
||||
from utils.logging import get_logger
|
||||
|
@ -36,8 +38,9 @@ class ImageSynthesiser(object):
|
|||
self.predictor = getattr(predictors, predictor_method)(self.config)
|
||||
|
||||
def synth_image(self, corpus, style_input, language="en"):
|
||||
corpus, text_input = self.text_drawer.draw_text(corpus, language)
|
||||
synth_result = self.predictor.predict(style_input, text_input)
|
||||
corpus_list, text_input_list = self.text_drawer.draw_text(
|
||||
corpus, language, style_input_width=style_input.shape[1])
|
||||
synth_result = self.predictor.predict(style_input, text_input_list)
|
||||
return synth_result
|
||||
|
||||
|
||||
|
@ -59,12 +62,15 @@ class DatasetSynthesiser(ImageSynthesiser):
|
|||
for i in range(self.output_num):
|
||||
style_data = self.style_sampler.sample()
|
||||
style_input = style_data["image"]
|
||||
corpus_language, text_input_label = self.corpus_generator.generate(
|
||||
)
|
||||
text_input_label, text_input = self.text_drawer.draw_text(
|
||||
text_input_label, corpus_language)
|
||||
corpus_language, text_input_label = self.corpus_generator.generate()
|
||||
text_input_label_list, text_input_list = self.text_drawer.draw_text(
|
||||
text_input_label,
|
||||
corpus_language,
|
||||
style_input_width=style_input.shape[1])
|
||||
|
||||
synth_result = self.predictor.predict(style_input, text_input)
|
||||
text_input_label = "".join(text_input_label_list)
|
||||
|
||||
synth_result = self.predictor.predict(style_input, text_input_list)
|
||||
fake_fusion = synth_result["fake_fusion"]
|
||||
self.writer.save_image(fake_fusion, text_input_label)
|
||||
self.writer.save_label()
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from PIL import Image, ImageDraw, ImageFont
|
||||
import numpy as np
|
||||
import cv2
|
||||
from utils.logging import get_logger
|
||||
|
||||
|
||||
|
@ -28,7 +29,11 @@ class StdTextDrawer(object):
|
|||
else:
|
||||
return int((self.height - 4)**2 / font_height)
|
||||
|
||||
def draw_text(self, corpus, language="en", crop=True):
|
||||
def draw_text(self,
|
||||
corpus,
|
||||
language="en",
|
||||
crop=True,
|
||||
style_input_width=None):
|
||||
if language not in self.support_languages:
|
||||
self.logger.warning(
|
||||
"language {} not supported, use en instead.".format(language))
|
||||
|
@ -37,21 +42,43 @@ class StdTextDrawer(object):
|
|||
width = min(self.max_width, len(corpus) * self.height) + 4
|
||||
else:
|
||||
width = len(corpus) * self.height + 4
|
||||
bg = Image.new("RGB", (width, self.height), color=(127, 127, 127))
|
||||
draw = ImageDraw.Draw(bg)
|
||||
|
||||
char_x = 2
|
||||
font = self.font_dict[language]
|
||||
for i, char_i in enumerate(corpus):
|
||||
char_size = font.getsize(char_i)[0]
|
||||
draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
|
||||
char_x += char_size
|
||||
if char_x >= width:
|
||||
corpus = corpus[0:i + 1]
|
||||
self.logger.warning("corpus length exceed limit: {}".format(
|
||||
corpus))
|
||||
if style_input_width is not None:
|
||||
width = min(width, style_input_width)
|
||||
|
||||
corpus_list = []
|
||||
text_input_list = []
|
||||
|
||||
while len(corpus) != 0:
|
||||
bg = Image.new("RGB", (width, self.height), color=(127, 127, 127))
|
||||
draw = ImageDraw.Draw(bg)
|
||||
char_x = 2
|
||||
font = self.font_dict[language]
|
||||
i = 0
|
||||
while i < len(corpus):
|
||||
char_i = corpus[i]
|
||||
char_size = font.getsize(char_i)[0]
|
||||
# split when char_x exceeds char size and index is not 0 (at least 1 char should be wroten on the image)
|
||||
if char_x + char_size >= width and i != 0:
|
||||
text_input = np.array(bg).astype(np.uint8)
|
||||
text_input = text_input[:, 0:char_x, :]
|
||||
|
||||
corpus_list.append(corpus[0:i])
|
||||
text_input_list.append(text_input)
|
||||
corpus = corpus[i:]
|
||||
break
|
||||
draw.text((char_x, 2), char_i, fill=(0, 0, 0), font=font)
|
||||
char_x += char_size
|
||||
|
||||
i += 1
|
||||
# the whole text is shorter than style input
|
||||
if i == len(corpus):
|
||||
text_input = np.array(bg).astype(np.uint8)
|
||||
text_input = text_input[:, 0:char_x, :]
|
||||
|
||||
corpus_list.append(corpus[0:i])
|
||||
text_input_list.append(text_input)
|
||||
corpus = corpus[i:]
|
||||
break
|
||||
|
||||
text_input = np.array(bg).astype(np.uint8)
|
||||
text_input = text_input[:, 0:char_x, :]
|
||||
return corpus, text_input
|
||||
return corpus_list, text_input_list
|
||||
|
|
|
@ -14,12 +14,13 @@ Global:
|
|||
load_static_weights: True
|
||||
cal_metric_during_train: False
|
||||
pretrained_model: ./pretrain_models/ResNet50_vd_ssld_pretrained/
|
||||
checkpoints:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img:
|
||||
infer_img:
|
||||
save_res_path: ./output/sast_r50_vd_ic15/predicts_sast.txt
|
||||
|
||||
|
||||
Architecture:
|
||||
model_type: det
|
||||
algorithm: SAST
|
||||
|
|
|
@ -0,0 +1,114 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 600
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/pgnet_r50_vd_totaltext/
|
||||
save_epoch_step: 10
|
||||
# evaluation is run every 0 iterationss after the 1000th iteration
|
||||
eval_batch_step: [ 0, 1000 ]
|
||||
# 1. If pretrained_model is saved in static mode, such as classification pretrained model
|
||||
# from static branch, load_static_weights must be set as True.
|
||||
# 2. If you want to finetune the pretrained models we provide in the docs,
|
||||
# you should set load_static_weights as False.
|
||||
load_static_weights: False
|
||||
cal_metric_during_train: False
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img:
|
||||
valid_set: totaltext # two mode: totaltext valid curved words, partvgg valid non-curved words
|
||||
save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt
|
||||
character_dict_path: ppocr/utils/ic15_dict.txt
|
||||
character_type: EN
|
||||
max_text_length: 50 # the max length in seq
|
||||
max_text_nums: 30 # the max seq nums in a pic
|
||||
tcl_len: 64
|
||||
|
||||
Architecture:
|
||||
model_type: e2e
|
||||
algorithm: PGNet
|
||||
Transform:
|
||||
Backbone:
|
||||
name: ResNet
|
||||
layers: 50
|
||||
Neck:
|
||||
name: PGFPN
|
||||
Head:
|
||||
name: PGHead
|
||||
|
||||
Loss:
|
||||
name: PGLoss
|
||||
tcl_bs: 64
|
||||
max_text_length: 50 # the same as Global: max_text_length
|
||||
max_text_nums: 30 # the same as Global:max_text_nums
|
||||
pad_num: 36 # the length of dict for pad
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
learning_rate: 0.001
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0
|
||||
|
||||
|
||||
PostProcess:
|
||||
name: PGPostProcess
|
||||
score_thresh: 0.5
|
||||
Metric:
|
||||
name: E2EMetric
|
||||
character_dict_path: ppocr/utils/ic15_dict.txt
|
||||
main_indicator: f_score_e2e
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: PGDataSet
|
||||
label_file_list: [.././train_data/total_text/train/]
|
||||
ratio_list: [1.0]
|
||||
data_format: icdar #two data format: icdar/textnet
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- PGProcessTrain:
|
||||
batch_size: 14 # same as loader: batch_size_per_card
|
||||
min_crop_size: 24
|
||||
min_text_size: 4
|
||||
max_text_size: 512
|
||||
- KeepKeys:
|
||||
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
drop_last: True
|
||||
batch_size_per_card: 14
|
||||
num_workers: 16
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: PGDataSet
|
||||
data_dir: ./train_data/
|
||||
label_file_list: [./train_data/total_text/test/]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: RGB
|
||||
channel_first: False
|
||||
- E2ELabelEncode:
|
||||
- E2EResizeForTest:
|
||||
max_side_len: 768
|
||||
- NormalizeImage:
|
||||
scale: 1./255.
|
||||
mean: [ 0.485, 0.456, 0.406 ]
|
||||
std: [ 0.229, 0.224, 0.225 ]
|
||||
order: 'hwc'
|
||||
- ToCHWImage:
|
||||
- KeepKeys:
|
||||
keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags' ]
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 1 # must be 1
|
||||
num_workers: 2
|
|
@ -131,7 +131,7 @@ if __name__ == '__main__':
|
|||
if FLAGS.val:
|
||||
global_config['Eval']['dataset']['label_file_list'] = [FLAGS.val]
|
||||
eval_label_path = os.path.join(project_path,FLAGS.val)
|
||||
loss_file(Eval_label_path)
|
||||
loss_file(eval_label_path)
|
||||
if FLAGS.dict:
|
||||
global_config['Global']['character_dict_path'] = FLAGS.dict
|
||||
dict_path = os.path.join(project_path,FLAGS.dict)
|
||||
|
|
|
@ -65,7 +65,7 @@ Metric:
|
|||
Train:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ../training/
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
|
@ -84,7 +84,7 @@ Train:
|
|||
Eval:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ../validation/
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
|
|
|
@ -64,7 +64,7 @@ Metric:
|
|||
Train:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ../training/
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
|
@ -83,7 +83,7 @@ Train:
|
|||
Eval:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ../validation/
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
|
|
|
@ -58,7 +58,7 @@ Metric:
|
|||
Train:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/srn_train_data_duiqi
|
||||
data_dir: ./train_data/data_lmdb_release/training/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
|
@ -83,7 +83,7 @@ Train:
|
|||
Eval:
|
||||
dataset:
|
||||
name: LMDBDataSet
|
||||
data_dir: ./train_data/data_lmdb_release/evaluation
|
||||
data_dir: ./train_data/data_lmdb_release/validation/
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
|
|
|
@ -133,7 +133,11 @@ if(WITH_MKL)
|
|||
endif ()
|
||||
endif()
|
||||
else()
|
||||
set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX})
|
||||
if (WIN32)
|
||||
set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/openblas${CMAKE_STATIC_LIBRARY_SUFFIX})
|
||||
else ()
|
||||
set(MATH_LIB ${PADDLE_LIB}/third_party/install/openblas/lib/libopenblas${CMAKE_STATIC_LIBRARY_SUFFIX})
|
||||
endif ()
|
||||
endif()
|
||||
|
||||
# Note: libpaddle_inference_api.so/a must put before libpaddle_fluid.so/a
|
||||
|
@ -157,7 +161,7 @@ endif(WITH_STATIC_LIB)
|
|||
|
||||
if (NOT WIN32)
|
||||
set(DEPS ${DEPS}
|
||||
${MATH_LIB} ${MKLDNN_LIB}
|
||||
${MATH_LIB} ${MKLDNN_LIB}
|
||||
glog gflags protobuf z xxhash
|
||||
)
|
||||
if(EXISTS "${PADDLE_LIB}/third_party/install/snappystream/lib")
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
# 服务器端C++预测
|
||||
|
||||
本教程将介绍在服务器端部署PaddleOCR超轻量中文检测、识别模型的详细步骤。
|
||||
本章节介绍PaddleOCR 模型的的C++部署方法,与之对应的python预测部署方式参考[文档](../../doc/doc_ch/inference.md)。
|
||||
C++在性能计算上优于python,因此,在大多数CPU、GPU部署场景,多采用C++的部署方式,本节将介绍如何在Linux\Windows (CPU\GPU)环境下配置C++环境并完成
|
||||
PaddleOCR模型部署。
|
||||
|
||||
|
||||
## 1. 准备环境
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
# Server-side C++ inference
|
||||
|
||||
|
||||
In this tutorial, we will introduce the detailed steps of deploying PaddleOCR ultra-lightweight Chinese detection and recognition models on the server side.
|
||||
This chapter introduces the C++ deployment method of the PaddleOCR model, and the corresponding python predictive deployment method refers to [document](../../doc/doc_ch/inference.md).
|
||||
C++ is better than python in terms of performance calculation. Therefore, in most CPU and GPU deployment scenarios, C++ deployment is mostly used.
|
||||
This section will introduce how to configure the C++ environment and complete it in the Linux\Windows (CPU\GPU) environment
|
||||
PaddleOCR model deployment.
|
||||
|
||||
|
||||
## 1. Prepare the environment
|
||||
|
|
|
@ -50,6 +50,11 @@ int main(int argc, char **argv) {
|
|||
|
||||
cv::Mat srcimg = cv::imread(img_path, cv::IMREAD_COLOR);
|
||||
|
||||
if (!srcimg.data) {
|
||||
std::cerr << "[ERROR] image read failed! image path: " << img_path << "\n";
|
||||
exit(1);
|
||||
}
|
||||
|
||||
DBDetector det(config.det_model_dir, config.use_gpu, config.gpu_id,
|
||||
config.gpu_mem, config.cpu_math_library_num_threads,
|
||||
config.use_mkldnn, config.max_side_len, config.det_db_thresh,
|
||||
|
|
|
@ -76,7 +76,7 @@ void CRNNRecognizer::Run(std::vector<std::vector<std::vector<int>>> boxes,
|
|||
float(*std::max_element(&predict_batch[n * predict_shape[2]],
|
||||
&predict_batch[(n + 1) * predict_shape[2]]));
|
||||
|
||||
if (argmax_idx > 0 && (!(i > 0 && argmax_idx == last_index))) {
|
||||
if (argmax_idx > 0 && (!(n > 0 && argmax_idx == last_index))) {
|
||||
score += max_value;
|
||||
count += 1;
|
||||
str_res.push_back(label_list_[argmax_idx]);
|
||||
|
|
|
@ -9,7 +9,7 @@ use_mkldnn 0
|
|||
max_side_len 960
|
||||
det_db_thresh 0.3
|
||||
det_db_box_thresh 0.5
|
||||
det_db_unclip_ratio 2.0
|
||||
det_db_unclip_ratio 1.6
|
||||
det_model_dir ./inference/ch_ppocr_mobile_v2.0_det_infer/
|
||||
|
||||
# cls config
|
||||
|
|
|
@ -20,7 +20,8 @@ def read_params():
|
|||
#DB parmas
|
||||
cfg.det_db_thresh = 0.3
|
||||
cfg.det_db_box_thresh = 0.5
|
||||
cfg.det_db_unclip_ratio = 2.0
|
||||
cfg.det_db_unclip_ratio = 1.6
|
||||
cfg.use_dilation = False
|
||||
|
||||
# #EAST parmas
|
||||
# cfg.det_east_score_thresh = 0.8
|
||||
|
|
|
@ -20,7 +20,8 @@ def read_params():
|
|||
#DB parmas
|
||||
cfg.det_db_thresh = 0.3
|
||||
cfg.det_db_box_thresh = 0.5
|
||||
cfg.det_db_unclip_ratio = 2.0
|
||||
cfg.det_db_unclip_ratio = 1.6
|
||||
cfg.use_dilation = False
|
||||
|
||||
#EAST parmas
|
||||
cfg.det_east_score_thresh = 0.8
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
PaddleOCR提供2种服务部署方式:
|
||||
- 基于PaddleHub Serving的部署:代码路径为"`./deploy/hubserving`",按照本教程使用;
|
||||
- (coming soon)基于PaddleServing的部署:代码路径为"`./deploy/pdserving`",使用方法参考[文档](../../deploy/pdserving/readme.md)。
|
||||
- 基于PaddleServing的部署:代码路径为"`./deploy/pdserving`",使用方法参考[文档](../../deploy/pdserving/README_CN.md)。
|
||||
|
||||
# 基于PaddleHub Serving的服务部署
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@ English | [简体中文](readme.md)
|
|||
|
||||
PaddleOCR provides 2 service deployment methods:
|
||||
- Based on **PaddleHub Serving**: Code path is "`./deploy/hubserving`". Please follow this tutorial.
|
||||
- (coming soon)Based on **PaddleServing**: Code path is "`./deploy/pdserving`". Please refer to the [tutorial](../../deploy/pdserving/readme.md) for usage.
|
||||
- Based on **PaddleServing**: Code path is "`./deploy/pdserving`". Please refer to the [tutorial](../../deploy/pdserving/README.md) for usage.
|
||||
|
||||
# Service deployment based on PaddleHub Serving
|
||||
|
||||
|
|
|
@ -0,0 +1,158 @@
|
|||
# OCR Pipeline WebService
|
||||
|
||||
(English|[简体中文](./README_CN.md))
|
||||
|
||||
PaddleOCR provides two service deployment methods:
|
||||
- Based on **PaddleHub Serving**: Code path is "`./deploy/hubserving`". Please refer to the [tutorial](../../deploy/hubserving/readme_en.md)
|
||||
- Based on **PaddleServing**: Code path is "`./deploy/pdserving`". Please follow this tutorial.
|
||||
|
||||
# Service deployment based on PaddleServing
|
||||
|
||||
This document will introduce how to use the [PaddleServing](https://github.com/PaddlePaddle/Serving/blob/develop/README.md) to deploy the PPOCR dynamic graph model as a pipeline online service.
|
||||
|
||||
Some Key Features of Paddle Serving:
|
||||
- Integrate with Paddle training pipeline seamlessly, most paddle models can be deployed with one line command.
|
||||
- Industrial serving features supported, such as models management, online loading, online A/B testing etc.
|
||||
- Highly concurrent and efficient communication between clients and servers supported.
|
||||
|
||||
The introduction and tutorial of Paddle Serving service deployment framework reference [document](https://github.com/PaddlePaddle/Serving/blob/develop/README.md).
|
||||
|
||||
|
||||
## Contents
|
||||
- [Environmental preparation](#environmental-preparation)
|
||||
- [Model conversion](#model-conversion)
|
||||
- [Paddle Serving pipeline deployment](#paddle-serving-pipeline-deployment)
|
||||
- [FAQ](#faq)
|
||||
|
||||
<a name="environmental-preparation"></a>
|
||||
## Environmental preparation
|
||||
|
||||
PaddleOCR operating environment and Paddle Serving operating environment are needed.
|
||||
|
||||
1. Please prepare PaddleOCR operating environment reference [link](../../doc/doc_ch/installation.md).
|
||||
|
||||
2. The steps of PaddleServing operating environment prepare are as follows:
|
||||
|
||||
Install serving which used to start the service
|
||||
```
|
||||
pip3 install paddle-serving-server==0.5.0 # for CPU
|
||||
pip3 install paddle-serving-server-gpu==0.5.0 # for GPU
|
||||
# Other GPU environments need to confirm the environment and then choose to execute the following commands
|
||||
pip3 install paddle-serving-server-gpu==0.5.0.post9 # GPU with CUDA9.0
|
||||
pip3 install paddle-serving-server-gpu==0.5.0.post10 # GPU with CUDA10.0
|
||||
pip3 install paddle-serving-server-gpu==0.5.0.post101 # GPU with CUDA10.1 + TensorRT6
|
||||
pip3 install paddle-serving-server-gpu==0.5.0.post11 # GPU with CUDA10.1 + TensorRT7
|
||||
```
|
||||
|
||||
3. Install the client to send requests to the service
|
||||
```
|
||||
pip3 install paddle-serving-client==0.5.0 # for CPU
|
||||
|
||||
pip3 install paddle-serving-client-gpu==0.5.0 # for GPU
|
||||
```
|
||||
|
||||
4. Install serving-app
|
||||
```
|
||||
pip3 install paddle-serving-app==0.3.0
|
||||
# fix local_predict to support load dynamic model
|
||||
# find the install directoory of paddle_serving_app
|
||||
vim /usr/local/lib/python3.7/site-packages/paddle_serving_app/local_predict.py
|
||||
# replace line 85 of local_predict.py config = AnalysisConfig(model_path) with:
|
||||
if os.path.exists(os.path.join(model_path, "__params__")):
|
||||
config = AnalysisConfig(os.path.join(model_path, "__model__"), os.path.join(model_path, "__params__"))
|
||||
else:
|
||||
config = AnalysisConfig(model_path)
|
||||
```
|
||||
|
||||
**note:** If you want to install the latest version of PaddleServing, refer to [link](https://github.com/PaddlePaddle/Serving/blob/develop/doc/LATEST_PACKAGES.md).
|
||||
|
||||
|
||||
<a name="model-conversion"></a>
|
||||
## Model conversion
|
||||
When using PaddleServing for service deployment, you need to convert the saved inference model into a serving model that is easy to deploy.
|
||||
|
||||
Firstly, download the [inference model](https://github.com/PaddlePaddle/PaddleOCR#pp-ocr-20-series-model-listupdate-on-dec-15) of PPOCR
|
||||
```
|
||||
# Download and unzip the OCR text detection model
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_det_infer.tar
|
||||
# Download and unzip the OCR text recognition model
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar
|
||||
|
||||
```
|
||||
Then, you can use installed paddle_serving_client tool to convert inference model to server model.
|
||||
```
|
||||
# Detection model conversion
|
||||
python3 -m paddle_serving_client.convert --dirname ./ch_ppocr_server_v2.0_det_infer/ \
|
||||
--model_filename inference.pdmodel \
|
||||
--params_filename inference.pdiparams \
|
||||
--serving_server ./ppocr_det_server_2.0_serving/ \
|
||||
--serving_client ./ppocr_det_server_2.0_client/
|
||||
|
||||
# Recognition model conversion
|
||||
python3 -m paddle_serving_client.convert --dirname ./ch_ppocr_server_v2.0_rec_infer/ \
|
||||
--model_filename inference.pdmodel \
|
||||
--params_filename inference.pdiparams \
|
||||
--serving_server ./ppocr_rec_server_2.0_serving/ \
|
||||
--serving_client ./ppocr_rec_server_2.0_client/
|
||||
|
||||
```
|
||||
|
||||
After the detection model is converted, there will be additional folders of `ppocr_det_server_2.0_serving` and `ppocr_det_server_2.0_client` in the current folder, with the following format:
|
||||
```
|
||||
|- ppocr_det_server_2.0_serving/
|
||||
|- __model__
|
||||
|- __params__
|
||||
|- serving_server_conf.prototxt
|
||||
|- serving_server_conf.stream.prototxt
|
||||
|
||||
|- ppocr_det_server_2.0_client
|
||||
|- serving_client_conf.prototxt
|
||||
|- serving_client_conf.stream.prototxt
|
||||
|
||||
```
|
||||
The recognition model is the same.
|
||||
|
||||
<a name="paddle-serving-pipeline-deployment"></a>
|
||||
## Paddle Serving pipeline deployment
|
||||
|
||||
1. Download the PaddleOCR code, if you have already downloaded it, you can skip this step.
|
||||
```
|
||||
git clone https://github.com/PaddlePaddle/PaddleOCR
|
||||
|
||||
# Enter the working directory
|
||||
cd PaddleOCR/deploy/pdserver/
|
||||
```
|
||||
|
||||
The pdserver directory contains the code to start the pipeline service and send prediction requests, including:
|
||||
```
|
||||
__init__.py
|
||||
config.yml # Start the service configuration file
|
||||
ocr_reader.py # OCR model pre-processing and post-processing code implementation
|
||||
pipeline_http_client.py # Script to send pipeline prediction request
|
||||
web_service.py # Start the script of the pipeline server
|
||||
```
|
||||
|
||||
2. Run the following command to start the service.
|
||||
```
|
||||
# Start the service and save the running log in log.txt
|
||||
python3 web_service.py &>log.txt &
|
||||
```
|
||||
After the service is successfully started, a log similar to the following will be printed in log.txt
|
||||
![](./imgs/start_server.png)
|
||||
|
||||
3. Send service request
|
||||
```
|
||||
python3 pipeline_http_client.py
|
||||
```
|
||||
After successfully running, the predicted result of the model will be printed in the cmd window. An example of the result is:
|
||||
![](./imgs/results.png)
|
||||
|
||||
<a name="faq"></a>
|
||||
## FAQ
|
||||
**Q1**: No result return after sending the request.
|
||||
|
||||
**A1**: Do not set the proxy when starting the service and sending the request. You can close the proxy before starting the service and before sending the request. The command to close the proxy is:
|
||||
```
|
||||
unset https_proxy
|
||||
unset http_proxy
|
||||
```
|
|
@ -0,0 +1,160 @@
|
|||
# PPOCR 服务化部署
|
||||
|
||||
([English](./README.md)|简体中文)
|
||||
|
||||
PaddleOCR提供2种服务部署方式:
|
||||
- 基于PaddleHub Serving的部署:代码路径为"`./deploy/hubserving`",使用方法参考[文档](../../deploy/hubserving/readme.md);
|
||||
- 基于PaddleServing的部署:代码路径为"`./deploy/pdserving`",按照本教程使用。
|
||||
|
||||
# 基于PaddleServing的服务部署
|
||||
|
||||
本文档将介绍如何使用[PaddleServing](https://github.com/PaddlePaddle/Serving/blob/develop/README_CN.md)工具部署PPOCR
|
||||
动态图模型的pipeline在线服务。
|
||||
|
||||
相比较于hubserving部署,PaddleServing具备以下优点:
|
||||
- 支持客户端和服务端之间高并发和高效通信
|
||||
- 支持 工业级的服务能力 例如模型管理,在线加载,在线A/B测试等
|
||||
- 支持 多种编程语言 开发客户端,例如C++, Python和Java
|
||||
|
||||
更多有关PaddleServing服务化部署框架介绍和使用教程参考[文档](https://github.com/PaddlePaddle/Serving/blob/develop/README_CN.md)。
|
||||
|
||||
## 目录
|
||||
- [环境准备](#环境准备)
|
||||
- [模型转换](#模型转换)
|
||||
- [Paddle Serving pipeline部署](#部署)
|
||||
- [FAQ](#FAQ)
|
||||
|
||||
<a name="环境准备"></a>
|
||||
## 环境准备
|
||||
|
||||
需要准备PaddleOCR的运行环境和Paddle Serving的运行环境。
|
||||
|
||||
- 准备PaddleOCR的运行环境参考[链接](../../doc/doc_ch/installation.md)
|
||||
|
||||
- 准备PaddleServing的运行环境,步骤如下
|
||||
|
||||
1. 安装serving,用于启动服务
|
||||
```
|
||||
pip3 install paddle-serving-server==0.5.0 # for CPU
|
||||
pip3 install paddle-serving-server-gpu==0.5.0 # for GPU
|
||||
# 其他GPU环境需要确认环境再选择执行如下命令
|
||||
pip3 install paddle-serving-server-gpu==0.5.0.post9 # GPU with CUDA9.0
|
||||
pip3 install paddle-serving-server-gpu==0.5.0.post10 # GPU with CUDA10.0
|
||||
pip3 install paddle-serving-server-gpu==0.5.0.post101 # GPU with CUDA10.1 + TensorRT6
|
||||
pip3 install paddle-serving-server-gpu==0.5.0.post11 # GPU with CUDA10.1 + TensorRT7
|
||||
```
|
||||
|
||||
2. 安装client,用于向服务发送请求
|
||||
```
|
||||
pip3 install paddle-serving-client==0.5.0 # for CPU
|
||||
|
||||
pip3 install paddle-serving-client-gpu==0.5.0 # for GPU
|
||||
```
|
||||
|
||||
3. 安装serving-app
|
||||
```
|
||||
pip3 install paddle-serving-app==0.3.0
|
||||
```
|
||||
**note:** 安装0.3.0版本的serving-app后,为了能加载动态图模型,需要修改serving_app的源码,具体为:
|
||||
```
|
||||
# 找到paddle_serving_app的安装目录,找到并编辑local_predict.py文件
|
||||
vim /usr/local/lib/python3.7/site-packages/paddle_serving_app/local_predict.py
|
||||
# 将local_predict.py 的第85行 config = AnalysisConfig(model_path) 替换为:
|
||||
if os.path.exists(os.path.join(model_path, "__params__")):
|
||||
config = AnalysisConfig(os.path.join(model_path, "__model__"), os.path.join(model_path, "__params__"))
|
||||
else:
|
||||
config = AnalysisConfig(model_path)
|
||||
```
|
||||
|
||||
**Note:** 如果要安装最新版本的PaddleServing参考[链接](https://github.com/PaddlePaddle/Serving/blob/develop/doc/LATEST_PACKAGES.md)。
|
||||
|
||||
<a name="模型转换"></a>
|
||||
## 模型转换
|
||||
|
||||
使用PaddleServing做服务化部署时,需要将保存的inference模型转换为serving易于部署的模型。
|
||||
|
||||
首先,下载PPOCR的[inference模型](https://github.com/PaddlePaddle/PaddleOCR#pp-ocr-20-series-model-listupdate-on-dec-15)
|
||||
```
|
||||
# 下载并解压 OCR 文本检测模型
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar && tar xf ch_ppocr_server_v2.0_det_infer.tar
|
||||
# 下载并解压 OCR 文本识别模型
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar
|
||||
```
|
||||
|
||||
接下来,用安装的paddle_serving_client把下载的inference模型转换成易于server部署的模型格式。
|
||||
|
||||
```
|
||||
# 转换检测模型
|
||||
python3 -m paddle_serving_client.convert --dirname ./ch_ppocr_server_v2.0_det_infer/ \
|
||||
--model_filename inference.pdmodel \
|
||||
--params_filename inference.pdiparams \
|
||||
--serving_server ./ppocr_det_server_2.0_serving/ \
|
||||
--serving_client ./ppocr_det_server_2.0_client/
|
||||
|
||||
# 转换识别模型
|
||||
python3 -m paddle_serving_client.convert --dirname ./ch_ppocr_server_v2.0_rec_infer/ \
|
||||
--model_filename inference.pdmodel \
|
||||
--params_filename inference.pdiparams \
|
||||
--serving_server ./ppocr_rec_server_2.0_serving/ \
|
||||
--serving_client ./ppocr_rec_server_2.0_client/
|
||||
```
|
||||
|
||||
检测模型转换完成后,会在当前文件夹多出`ppocr_det_server_2.0_serving` 和`ppocr_det_server_2.0_client`的文件夹,具备如下格式:
|
||||
```
|
||||
|- ppocr_det_server_2.0_serving/
|
||||
|- __model__
|
||||
|- __params__
|
||||
|- serving_server_conf.prototxt
|
||||
|- serving_server_conf.stream.prototxt
|
||||
|
||||
|- ppocr_det_server_2.0_client
|
||||
|- serving_client_conf.prototxt
|
||||
|- serving_client_conf.stream.prototxt
|
||||
|
||||
```
|
||||
识别模型同理。
|
||||
|
||||
<a name="部署"></a>
|
||||
## Paddle Serving pipeline部署
|
||||
|
||||
1. 下载PaddleOCR代码,若已下载可跳过此步骤
|
||||
```
|
||||
git clone https://github.com/PaddlePaddle/PaddleOCR
|
||||
|
||||
# 进入到工作目录
|
||||
cd PaddleOCR/deploy/pdserver/
|
||||
```
|
||||
pdserver目录包含启动pipeline服务和发送预测请求的代码,包括:
|
||||
```
|
||||
__init__.py
|
||||
config.yml # 启动服务的配置文件
|
||||
ocr_reader.py # OCR模型预处理和后处理的代码实现
|
||||
pipeline_http_client.py # 发送pipeline预测请求的脚本
|
||||
web_service.py # 启动pipeline服务端的脚本
|
||||
```
|
||||
|
||||
2. 启动服务可运行如下命令:
|
||||
```
|
||||
# 启动服务,运行日志保存在log.txt
|
||||
python3 web_service.py &>log.txt &
|
||||
```
|
||||
成功启动服务后,log.txt中会打印类似如下日志
|
||||
![](./imgs/start_server.png)
|
||||
|
||||
3. 发送服务请求:
|
||||
```
|
||||
python3 pipeline_http_client.py
|
||||
```
|
||||
成功运行后,模型预测的结果会打印在cmd窗口中,结果示例为:
|
||||
![](./imgs/results.png)
|
||||
|
||||
|
||||
<a name="FAQ"></a>
|
||||
## FAQ
|
||||
**Q1**: 发送请求后没有结果返回或者提示输出解码报错
|
||||
|
||||
**A1**: 启动服务和发送请求时不要设置代理,可以在启动服务前和发送请求前关闭代理,关闭代理的命令是:
|
||||
```
|
||||
unset https_proxy
|
||||
unset http_proxy
|
||||
```
|
|
@ -0,0 +1,13 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
|
@ -0,0 +1,71 @@
|
|||
#rpc端口, rpc_port和http_port不允许同时为空。当rpc_port为空且http_port不为空时,会自动将rpc_port设置为http_port+1
|
||||
rpc_port: 18090
|
||||
|
||||
#http端口, rpc_port和http_port不允许同时为空。当rpc_port可用且http_port为空时,不自动生成http_port
|
||||
http_port: 9999
|
||||
|
||||
#worker_num, 最大并发数。当build_dag_each_worker=True时, 框架会创建worker_num个进程,每个进程内构建grpcSever和DAG
|
||||
##当build_dag_each_worker=False时,框架会设置主线程grpc线程池的max_workers=worker_num
|
||||
worker_num: 20
|
||||
|
||||
#build_dag_each_worker, False,框架在进程内创建一条DAG;True,框架会每个进程内创建多个独立的DAG
|
||||
build_dag_each_worker: false
|
||||
|
||||
dag:
|
||||
#op资源类型, True, 为线程模型;False,为进程模型
|
||||
is_thread_op: False
|
||||
|
||||
#重试次数
|
||||
retry: 1
|
||||
|
||||
#使用性能分析, True,生成Timeline性能数据,对性能有一定影响;False为不使用
|
||||
use_profile: False
|
||||
|
||||
tracer:
|
||||
interval_s: 10
|
||||
op:
|
||||
det:
|
||||
#并发数,is_thread_op=True时,为线程并发;否则为进程并发
|
||||
concurrency: 4
|
||||
|
||||
#当op配置没有server_endpoints时,从local_service_conf读取本地服务配置
|
||||
local_service_conf:
|
||||
#client类型,包括brpc, grpc和local_predictor.local_predictor不启动Serving服务,进程内预测
|
||||
client_type: local_predictor
|
||||
|
||||
#det模型路径
|
||||
model_config: /paddle/serving/models/det_serving_server/ #ocr_det_model
|
||||
|
||||
#Fetch结果列表,以client_config中fetch_var的alias_name为准
|
||||
fetch_list: ["save_infer_model/scale_0.tmp_1"]
|
||||
|
||||
#计算硬件ID,当devices为""或不写时为CPU预测;当devices为"0", "0,1,2"时为GPU预测,表示使用的GPU卡
|
||||
devices: "2"
|
||||
|
||||
ir_optim: True
|
||||
rec:
|
||||
#并发数,is_thread_op=True时,为线程并发;否则为进程并发
|
||||
concurrency: 1
|
||||
|
||||
#超时时间, 单位ms
|
||||
timeout: -1
|
||||
|
||||
#Serving交互重试次数,默认不重试
|
||||
retry: 1
|
||||
|
||||
#当op配置没有server_endpoints时,从local_service_conf读取本地服务配置
|
||||
local_service_conf:
|
||||
|
||||
#client类型,包括brpc, grpc和local_predictor。local_predictor不启动Serving服务,进程内预测
|
||||
client_type: local_predictor
|
||||
|
||||
#rec模型路径
|
||||
model_config: /paddle/serving/models/rec_serving_server/ #ocr_rec_model
|
||||
|
||||
#Fetch结果列表,以client_config中fetch_var的alias_name为准
|
||||
fetch_list: ["save_infer_model/scale_0.tmp_1"] #["ctc_greedy_decoder_0.tmp_0", "softmax_0.tmp_0"]
|
||||
|
||||
#计算硬件ID,当devices为""或不写时为CPU预测;当devices为"0", "0,1,2"时为GPU预测,表示使用的GPU卡
|
||||
devices: "2"
|
||||
|
||||
ir_optim: True
|
Before Width: | Height: | Size: 26 KiB After Width: | Height: | Size: 26 KiB |
Before Width: | Height: | Size: 998 KiB After Width: | Height: | Size: 998 KiB |
After Width: | Height: | Size: 119 KiB |
After Width: | Height: | Size: 195 KiB |
|
@ -0,0 +1,438 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import cv2
|
||||
import copy
|
||||
import numpy as np
|
||||
import math
|
||||
import re
|
||||
import sys
|
||||
import argparse
|
||||
import string
|
||||
from copy import deepcopy
|
||||
import paddle
|
||||
|
||||
|
||||
class DetResizeForTest(object):
|
||||
def __init__(self, **kwargs):
|
||||
super(DetResizeForTest, self).__init__()
|
||||
self.resize_type = 0
|
||||
if 'image_shape' in kwargs:
|
||||
self.image_shape = kwargs['image_shape']
|
||||
self.resize_type = 1
|
||||
elif 'limit_side_len' in kwargs:
|
||||
self.limit_side_len = kwargs['limit_side_len']
|
||||
self.limit_type = kwargs.get('limit_type', 'min')
|
||||
elif 'resize_long' in kwargs:
|
||||
self.resize_type = 2
|
||||
self.resize_long = kwargs.get('resize_long', 960)
|
||||
else:
|
||||
self.limit_side_len = 736
|
||||
self.limit_type = 'min'
|
||||
|
||||
def __call__(self, data):
|
||||
img = deepcopy(data)
|
||||
src_h, src_w, _ = img.shape
|
||||
|
||||
if self.resize_type == 0:
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type0(img)
|
||||
elif self.resize_type == 2:
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type2(img)
|
||||
else:
|
||||
img, [ratio_h, ratio_w] = self.resize_image_type1(img)
|
||||
|
||||
return img
|
||||
|
||||
def resize_image_type1(self, img):
|
||||
resize_h, resize_w = self.image_shape
|
||||
ori_h, ori_w = img.shape[:2] # (h, w, c)
|
||||
ratio_h = float(resize_h) / ori_h
|
||||
ratio_w = float(resize_w) / ori_w
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def resize_image_type0(self, img):
|
||||
"""
|
||||
resize image to a size multiple of 32 which is required by the network
|
||||
args:
|
||||
img(array): array with shape [h, w, c]
|
||||
return(tuple):
|
||||
img, (ratio_h, ratio_w)
|
||||
"""
|
||||
limit_side_len = self.limit_side_len
|
||||
h, w, _ = img.shape
|
||||
|
||||
# limit the max side
|
||||
if self.limit_type == 'max':
|
||||
if max(h, w) > limit_side_len:
|
||||
if h > w:
|
||||
ratio = float(limit_side_len) / h
|
||||
else:
|
||||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.
|
||||
else:
|
||||
if min(h, w) < limit_side_len:
|
||||
if h < w:
|
||||
ratio = float(limit_side_len) / h
|
||||
else:
|
||||
ratio = float(limit_side_len) / w
|
||||
else:
|
||||
ratio = 1.
|
||||
resize_h = int(h * ratio)
|
||||
resize_w = int(w * ratio)
|
||||
|
||||
resize_h = int(round(resize_h / 32) * 32)
|
||||
resize_w = int(round(resize_w / 32) * 32)
|
||||
|
||||
try:
|
||||
if int(resize_w) <= 0 or int(resize_h) <= 0:
|
||||
return None, (None, None)
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
except:
|
||||
print(img.shape, resize_w, resize_h)
|
||||
sys.exit(0)
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
# return img, np.array([h, w])
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def resize_image_type2(self, img):
|
||||
h, w, _ = img.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
# Fix the longer side
|
||||
if resize_h > resize_w:
|
||||
ratio = float(self.resize_long) / resize_h
|
||||
else:
|
||||
ratio = float(self.resize_long) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
img = cv2.resize(img, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
|
||||
class BaseRecLabelDecode(object):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, config):
|
||||
support_character_type = [
|
||||
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
|
||||
'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
|
||||
'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
|
||||
'ne', 'EN'
|
||||
]
|
||||
character_type = config['character_type']
|
||||
character_dict_path = config['character_dict_path']
|
||||
use_space_char = True
|
||||
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
||||
support_character_type, character_type)
|
||||
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
|
||||
if character_type == "en":
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
elif character_type == "EN_symbol":
|
||||
# same with ASTER setting (use 94 char).
|
||||
self.character_str = string.printable[:-6]
|
||||
dict_character = list(self.character_str)
|
||||
elif character_type in support_character_type:
|
||||
self.character_str = ""
|
||||
assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
|
||||
character_type)
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||
self.character_str += line
|
||||
if use_space_char:
|
||||
self.character_str += " "
|
||||
dict_character = list(self.character_str)
|
||||
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.character_type = character_type
|
||||
dict_character = self.add_special_char(dict_character)
|
||||
self.dict = {}
|
||||
for i, char in enumerate(dict_character):
|
||||
self.dict[char] = i
|
||||
self.character = dict_character
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
return dict_character
|
||||
|
||||
def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
result_list = []
|
||||
ignored_tokens = self.get_ignored_tokens()
|
||||
batch_size = len(text_index)
|
||||
for batch_idx in range(batch_size):
|
||||
char_list = []
|
||||
conf_list = []
|
||||
for idx in range(len(text_index[batch_idx])):
|
||||
if text_index[batch_idx][idx] in ignored_tokens:
|
||||
continue
|
||||
if is_remove_duplicate:
|
||||
# only for predict
|
||||
if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
|
||||
batch_idx][idx]:
|
||||
continue
|
||||
char_list.append(self.character[int(text_index[batch_idx][
|
||||
idx])])
|
||||
if text_prob is not None:
|
||||
conf_list.append(text_prob[batch_idx][idx])
|
||||
else:
|
||||
conf_list.append(1)
|
||||
text = ''.join(char_list)
|
||||
result_list.append((text, np.mean(conf_list)))
|
||||
return result_list
|
||||
|
||||
def get_ignored_tokens(self):
|
||||
return [0] # for ctc blank
|
||||
|
||||
|
||||
class CTCLabelDecode(BaseRecLabelDecode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
#character_dict_path=None,
|
||||
#character_type='ch',
|
||||
#use_space_char=False,
|
||||
**kwargs):
|
||||
super(CTCLabelDecode, self).__init__(config)
|
||||
|
||||
def __call__(self, preds, label=None, *args, **kwargs):
|
||||
if isinstance(preds, paddle.Tensor):
|
||||
preds = preds.numpy()
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
|
||||
if label is None:
|
||||
return text
|
||||
label = self.decode(label)
|
||||
return text, label
|
||||
|
||||
def add_special_char(self, dict_character):
|
||||
dict_character = ['blank'] + dict_character
|
||||
return dict_character
|
||||
|
||||
|
||||
class CharacterOps(object):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, config):
|
||||
self.character_type = config['character_type']
|
||||
self.loss_type = config['loss_type']
|
||||
if self.character_type == "en":
|
||||
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
|
||||
dict_character = list(self.character_str)
|
||||
elif self.character_type == "ch":
|
||||
character_dict_path = config['character_dict_path']
|
||||
self.character_str = ""
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||
self.character_str += line
|
||||
dict_character = list(self.character_str)
|
||||
elif self.character_type == "en_sensitive":
|
||||
# same with ASTER setting (use 94 char).
|
||||
self.character_str = string.printable[:-6]
|
||||
dict_character = list(self.character_str)
|
||||
else:
|
||||
self.character_str = None
|
||||
assert self.character_str is not None, \
|
||||
"Nonsupport type of the character: {}".format(self.character_str)
|
||||
self.beg_str = "sos"
|
||||
self.end_str = "eos"
|
||||
if self.loss_type == "attention":
|
||||
dict_character = [self.beg_str, self.end_str] + dict_character
|
||||
self.dict = {}
|
||||
for i, char in enumerate(dict_character):
|
||||
self.dict[char] = i
|
||||
self.character = dict_character
|
||||
|
||||
def encode(self, text):
|
||||
"""convert text-label into text-index.
|
||||
input:
|
||||
text: text labels of each image. [batch_size]
|
||||
|
||||
output:
|
||||
text: concatenated text index for CTCLoss.
|
||||
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
|
||||
length: length of each text. [batch_size]
|
||||
"""
|
||||
if self.character_type == "en":
|
||||
text = text.lower()
|
||||
|
||||
text_list = []
|
||||
for char in text:
|
||||
if char not in self.dict:
|
||||
continue
|
||||
text_list.append(self.dict[char])
|
||||
text = np.array(text_list)
|
||||
return text
|
||||
|
||||
def decode(self, text_index, is_remove_duplicate=False):
|
||||
""" convert text-index into text-label. """
|
||||
char_list = []
|
||||
char_num = self.get_char_num()
|
||||
|
||||
if self.loss_type == "attention":
|
||||
beg_idx = self.get_beg_end_flag_idx("beg")
|
||||
end_idx = self.get_beg_end_flag_idx("end")
|
||||
ignored_tokens = [beg_idx, end_idx]
|
||||
else:
|
||||
ignored_tokens = [char_num]
|
||||
|
||||
for idx in range(len(text_index)):
|
||||
if text_index[idx] in ignored_tokens:
|
||||
continue
|
||||
if is_remove_duplicate:
|
||||
if idx > 0 and text_index[idx - 1] == text_index[idx]:
|
||||
continue
|
||||
char_list.append(self.character[text_index[idx]])
|
||||
text = ''.join(char_list)
|
||||
return text
|
||||
|
||||
def get_char_num(self):
|
||||
return len(self.character)
|
||||
|
||||
def get_beg_end_flag_idx(self, beg_or_end):
|
||||
if self.loss_type == "attention":
|
||||
if beg_or_end == "beg":
|
||||
idx = np.array(self.dict[self.beg_str])
|
||||
elif beg_or_end == "end":
|
||||
idx = np.array(self.dict[self.end_str])
|
||||
else:
|
||||
assert False, "Unsupport type %s in get_beg_end_flag_idx"\
|
||||
% beg_or_end
|
||||
return idx
|
||||
else:
|
||||
err = "error in get_beg_end_flag_idx when using the loss %s"\
|
||||
% (self.loss_type)
|
||||
assert False, err
|
||||
|
||||
|
||||
class OCRReader(object):
|
||||
def __init__(self,
|
||||
algorithm="CRNN",
|
||||
image_shape=[3, 32, 320],
|
||||
char_type="ch",
|
||||
batch_num=1,
|
||||
char_dict_path="./ppocr_keys_v1.txt"):
|
||||
self.rec_image_shape = image_shape
|
||||
self.character_type = char_type
|
||||
self.rec_batch_num = batch_num
|
||||
char_ops_params = {}
|
||||
char_ops_params["character_type"] = char_type
|
||||
char_ops_params["character_dict_path"] = char_dict_path
|
||||
char_ops_params['loss_type'] = 'ctc'
|
||||
self.char_ops = CharacterOps(char_ops_params)
|
||||
self.label_ops = CTCLabelDecode(char_ops_params)
|
||||
|
||||
def resize_norm_img(self, img, max_wh_ratio):
|
||||
imgC, imgH, imgW = self.rec_image_shape
|
||||
if self.character_type == "ch":
|
||||
imgW = int(32 * max_wh_ratio)
|
||||
h = img.shape[0]
|
||||
w = img.shape[1]
|
||||
ratio = w / float(h)
|
||||
if math.ceil(imgH * ratio) > imgW:
|
||||
resized_w = imgW
|
||||
else:
|
||||
resized_w = int(math.ceil(imgH * ratio))
|
||||
resized_image = cv2.resize(img, (resized_w, imgH))
|
||||
resized_image = resized_image.astype('float32')
|
||||
resized_image = resized_image.transpose((2, 0, 1)) / 255
|
||||
resized_image -= 0.5
|
||||
resized_image /= 0.5
|
||||
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
|
||||
|
||||
padding_im[:, :, 0:resized_w] = resized_image
|
||||
return padding_im
|
||||
|
||||
def preprocess(self, img_list):
|
||||
img_num = len(img_list)
|
||||
norm_img_batch = []
|
||||
max_wh_ratio = 0
|
||||
for ino in range(img_num):
|
||||
h, w = img_list[ino].shape[0:2]
|
||||
wh_ratio = w * 1.0 / h
|
||||
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
||||
|
||||
for ino in range(img_num):
|
||||
norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
|
||||
norm_img = norm_img[np.newaxis, :]
|
||||
norm_img_batch.append(norm_img)
|
||||
norm_img_batch = np.concatenate(norm_img_batch)
|
||||
norm_img_batch = norm_img_batch.copy()
|
||||
|
||||
return norm_img_batch[0]
|
||||
|
||||
def postprocess_old(self, outputs, with_score=False):
|
||||
rec_res = []
|
||||
rec_idx_lod = outputs["ctc_greedy_decoder_0.tmp_0.lod"]
|
||||
rec_idx_batch = outputs["ctc_greedy_decoder_0.tmp_0"]
|
||||
if with_score:
|
||||
predict_lod = outputs["softmax_0.tmp_0.lod"]
|
||||
for rno in range(len(rec_idx_lod) - 1):
|
||||
beg = rec_idx_lod[rno]
|
||||
end = rec_idx_lod[rno + 1]
|
||||
if isinstance(rec_idx_batch, list):
|
||||
rec_idx_tmp = [x[0] for x in rec_idx_batch[beg:end]]
|
||||
else: #nd array
|
||||
rec_idx_tmp = rec_idx_batch[beg:end, 0]
|
||||
preds_text = self.char_ops.decode(rec_idx_tmp)
|
||||
if with_score:
|
||||
beg = predict_lod[rno]
|
||||
end = predict_lod[rno + 1]
|
||||
if isinstance(outputs["softmax_0.tmp_0"], list):
|
||||
outputs["softmax_0.tmp_0"] = np.array(outputs[
|
||||
"softmax_0.tmp_0"]).astype(np.float32)
|
||||
probs = outputs["softmax_0.tmp_0"][beg:end, :]
|
||||
ind = np.argmax(probs, axis=1)
|
||||
blank = probs.shape[1]
|
||||
valid_ind = np.where(ind != (blank - 1))[0]
|
||||
score = np.mean(probs[valid_ind, ind[valid_ind]])
|
||||
rec_res.append([preds_text, score])
|
||||
else:
|
||||
rec_res.append([preds_text])
|
||||
return rec_res
|
||||
|
||||
def postprocess(self, outputs, with_score=False):
|
||||
preds = outputs["save_infer_model/scale_0.tmp_1"]
|
||||
try:
|
||||
preds = preds.numpy()
|
||||
except:
|
||||
pass
|
||||
preds_idx = preds.argmax(axis=2)
|
||||
preds_prob = preds.max(axis=2)
|
||||
text = self.label_ops.decode(
|
||||
preds_idx, preds_prob, is_remove_duplicate=True)
|
||||
return text
|
|
@ -0,0 +1,40 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
import json
|
||||
import base64
|
||||
import os
|
||||
|
||||
|
||||
def cv2_to_base64(image):
|
||||
return base64.b64encode(image).decode('utf8')
|
||||
|
||||
|
||||
url = "http://127.0.0.1:9999/ocr/prediction"
|
||||
test_img_dir = "../doc/imgs/"
|
||||
for idx, img_file in enumerate(os.listdir(test_img_dir)):
|
||||
with open(os.path.join(test_img_dir, img_file), 'rb') as file:
|
||||
image_data1 = file.read()
|
||||
|
||||
image = cv2_to_base64(image_data1)
|
||||
|
||||
for i in range(1):
|
||||
data = {"key": ["image"], "value": [image]}
|
||||
r = requests.post(url=url, data=json.dumps(data))
|
||||
print(r.json())
|
||||
|
||||
test_img_dir = "../doc/imgs/"
|
||||
print("==> total number of test imgs: ", len(os.listdir(test_img_dir)))
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
try:
|
||||
from paddle_serving_server_gpu.pipeline import PipelineClient
|
||||
except ImportError:
|
||||
from paddle_serving_server.pipeline import PipelineClient
|
||||
import numpy as np
|
||||
import requests
|
||||
import json
|
||||
import cv2
|
||||
import base64
|
||||
import os
|
||||
|
||||
client = PipelineClient()
|
||||
client.connect(['127.0.0.1:18090'])
|
||||
|
||||
|
||||
def cv2_to_base64(image):
|
||||
return base64.b64encode(image).decode('utf8')
|
||||
|
||||
|
||||
test_img_dir = "imgs/"
|
||||
for img_file in os.listdir(test_img_dir):
|
||||
with open(os.path.join(test_img_dir, img_file), 'rb') as file:
|
||||
image_data = file.read()
|
||||
image = cv2_to_base64(image_data)
|
||||
|
||||
for i in range(1):
|
||||
ret = client.predict(feed_dict={"image": image}, fetch=["res"])
|
||||
print(ret)
|
||||
#print(ret)
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
try:
|
||||
from paddle_serving_server_gpu.web_service import WebService, Op
|
||||
except ImportError:
|
||||
from paddle_serving_server.web_service import WebService, Op
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import cv2
|
||||
import base64
|
||||
# from paddle_serving_app.reader import OCRReader
|
||||
from ocr_reader import OCRReader, DetResizeForTest
|
||||
from paddle_serving_app.reader import Sequential, ResizeByFactor
|
||||
from paddle_serving_app.reader import Div, Normalize, Transpose
|
||||
from paddle_serving_app.reader import DBPostProcess, FilterBoxes, GetRotateCropImage, SortedBoxes
|
||||
|
||||
_LOGGER = logging.getLogger()
|
||||
|
||||
|
||||
class DetOp(Op):
|
||||
def init_op(self):
|
||||
self.det_preprocess = Sequential([
|
||||
DetResizeForTest(), Div(255),
|
||||
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), Transpose(
|
||||
(2, 0, 1))
|
||||
])
|
||||
self.filter_func = FilterBoxes(10, 10)
|
||||
self.post_func = DBPostProcess({
|
||||
"thresh": 0.3,
|
||||
"box_thresh": 0.5,
|
||||
"max_candidates": 1000,
|
||||
"unclip_ratio": 1.5,
|
||||
"min_size": 3
|
||||
})
|
||||
|
||||
def preprocess(self, input_dicts, data_id, log_id):
|
||||
(_, input_dict), = input_dicts.items()
|
||||
data = base64.b64decode(input_dict["image"].encode('utf8'))
|
||||
data = np.fromstring(data, np.uint8)
|
||||
# Note: class variables(self.var) can only be used in process op mode
|
||||
im = cv2.imdecode(data, cv2.IMREAD_COLOR)
|
||||
self.im = im
|
||||
self.ori_h, self.ori_w, _ = im.shape
|
||||
|
||||
det_img = self.det_preprocess(self.im)
|
||||
_, self.new_h, self.new_w = det_img.shape
|
||||
print("det image shape", det_img.shape)
|
||||
return {"x": det_img[np.newaxis, :].copy()}, False, None, ""
|
||||
|
||||
def postprocess(self, input_dicts, fetch_dict, log_id):
|
||||
print("input_dicts: ", input_dicts)
|
||||
det_out = fetch_dict["save_infer_model/scale_0.tmp_1"]
|
||||
ratio_list = [
|
||||
float(self.new_h) / self.ori_h, float(self.new_w) / self.ori_w
|
||||
]
|
||||
dt_boxes_list = self.post_func(det_out, [ratio_list])
|
||||
dt_boxes = self.filter_func(dt_boxes_list[0], [self.ori_h, self.ori_w])
|
||||
out_dict = {"dt_boxes": dt_boxes, "image": self.im}
|
||||
|
||||
print("out dict", out_dict["dt_boxes"])
|
||||
return out_dict, None, ""
|
||||
|
||||
|
||||
class RecOp(Op):
|
||||
def init_op(self):
|
||||
self.ocr_reader = OCRReader(
|
||||
char_dict_path="../../ppocr/utils/ppocr_keys_v1.txt")
|
||||
|
||||
self.get_rotate_crop_image = GetRotateCropImage()
|
||||
self.sorted_boxes = SortedBoxes()
|
||||
|
||||
def preprocess(self, input_dicts, data_id, log_id):
|
||||
(_, input_dict), = input_dicts.items()
|
||||
im = input_dict["image"]
|
||||
dt_boxes = input_dict["dt_boxes"]
|
||||
dt_boxes = self.sorted_boxes(dt_boxes)
|
||||
feed_list = []
|
||||
img_list = []
|
||||
max_wh_ratio = 0
|
||||
for i, dtbox in enumerate(dt_boxes):
|
||||
boximg = self.get_rotate_crop_image(im, dt_boxes[i])
|
||||
img_list.append(boximg)
|
||||
h, w = boximg.shape[0:2]
|
||||
wh_ratio = w * 1.0 / h
|
||||
max_wh_ratio = max(max_wh_ratio, wh_ratio)
|
||||
_, w, h = self.ocr_reader.resize_norm_img(img_list[0],
|
||||
max_wh_ratio).shape
|
||||
|
||||
imgs = np.zeros((len(img_list), 3, w, h)).astype('float32')
|
||||
for id, img in enumerate(img_list):
|
||||
norm_img = self.ocr_reader.resize_norm_img(img, max_wh_ratio)
|
||||
imgs[id] = norm_img
|
||||
print("rec image shape", imgs.shape)
|
||||
feed = {"x": imgs.copy()}
|
||||
return feed, False, None, ""
|
||||
|
||||
def postprocess(self, input_dicts, fetch_dict, log_id):
|
||||
rec_res = self.ocr_reader.postprocess(fetch_dict, with_score=True)
|
||||
res_lst = []
|
||||
for res in rec_res:
|
||||
res_lst.append(res[0])
|
||||
res = {"res": str(res_lst)}
|
||||
return res, None, ""
|
||||
|
||||
|
||||
class OcrService(WebService):
|
||||
def get_pipeline_response(self, read_op):
|
||||
det_op = DetOp(name="det", input_ops=[read_op])
|
||||
rec_op = RecOp(name="rec", input_ops=[det_op])
|
||||
return rec_op
|
||||
|
||||
|
||||
uci_service = OcrService(name="ocr")
|
||||
uci_service.prepare_pipeline_config("config.yml")
|
||||
uci_service.run_service()
|
|
@ -0,0 +1,64 @@
|
|||
|
||||
## 介绍
|
||||
|
||||
复杂的模型有利于提高模型的性能,但也导致模型中存在一定冗余,模型裁剪通过移出网络模型中的子模型来减少这种冗余,达到减少模型计算复杂度,提高模型推理性能的目的。
|
||||
本教程将介绍如何使用飞桨模型压缩库PaddleSlim做PaddleOCR模型的压缩。
|
||||
[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim)集成了模型剪枝、量化(包括量化训练和离线量化)、蒸馏和神经网络搜索等多种业界常用且领先的模型压缩功能,如果您感兴趣,可以关注并了解。
|
||||
|
||||
|
||||
在开始本教程之前,建议先了解:
|
||||
1. [PaddleOCR模型的训练方法](../../../doc/doc_ch/quickstart.md)
|
||||
2. [模型裁剪教程](https://github.com/PaddlePaddle/PaddleSlim/blob/release%2F2.0.0/docs/zh_cn/tutorials/pruning/dygraph/filter_pruning.md)
|
||||
|
||||
|
||||
## 快速开始
|
||||
|
||||
模型裁剪主要包括四个步骤:
|
||||
1. 安装 PaddleSlim
|
||||
2. 准备训练好的模型
|
||||
3. 敏感度分析、裁剪训练
|
||||
4. 导出模型、预测部署
|
||||
|
||||
### 1. 安装PaddleSlim
|
||||
|
||||
```bash
|
||||
git clone https://github.com/PaddlePaddle/PaddleSlim.git
|
||||
git checkout develop
|
||||
cd Paddleslim
|
||||
python3 setup.py install
|
||||
```
|
||||
|
||||
### 2. 获取预训练模型
|
||||
模型裁剪需要加载事先训练好的模型,PaddleOCR也提供了一系列(模型)[../../../doc/doc_ch/models_list.md],开发者可根据需要自行选择模型或使用自己的模型。
|
||||
|
||||
### 3. 敏感度分析训练
|
||||
|
||||
加载预训练模型后,通过对现有模型的每个网络层进行敏感度分析,得到敏感度文件:sen.pickle,可以通过PaddleSlim提供的[接口](https://github.com/PaddlePaddle/PaddleSlim/blob/9b01b195f0c4bc34a1ab434751cb260e13d64d9e/paddleslim/dygraph/prune/filter_pruner.py#L75)加载文件,获得各网络层在不同裁剪比例下的精度损失。从而了解各网络层冗余度,决定每个网络层的裁剪比例。
|
||||
敏感度文件内容格式:
|
||||
sen.pickle(Dict){
|
||||
'layer_weight_name_0': sens_of_each_ratio(Dict){'pruning_ratio_0': acc_loss, 'pruning_ratio_1': acc_loss}
|
||||
'layer_weight_name_1': sens_of_each_ratio(Dict){'pruning_ratio_0': acc_loss, 'pruning_ratio_1': acc_loss}
|
||||
}
|
||||
|
||||
例子:
|
||||
{
|
||||
'conv10_expand_weights': {0.1: 0.006509952684312718, 0.2: 0.01827734339798862, 0.3: 0.014528405644659832, 0.6: 0.06536008804270439, 0.8: 0.11798612250664964, 0.7: 0.12391408417493704, 0.4: 0.030615754498018757, 0.5: 0.047105205602406594}
|
||||
'conv10_linear_weights': {0.1: 0.05113190831455035, 0.2: 0.07705573833558801, 0.3: 0.12096721757739311, 0.6: 0.5135061352930738, 0.8: 0.7908166677143281, 0.7: 0.7272187676899062, 0.4: 0.1819252083008504, 0.5: 0.3728054727792405}
|
||||
}
|
||||
加载敏感度文件后会返回一个字典,字典中的keys为网络模型参数模型的名字,values为一个字典,里面保存了相应网络层的裁剪敏感度信息。例如在例子中,conv10_expand_weights所对应的网络层在裁掉10%的卷积核后模型性能相较原模型会下降0.65%,详细信息可见[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/algo/algo.md#2-%E5%8D%B7%E7%A7%AF%E6%A0%B8%E5%89%AA%E8%A3%81%E5%8E%9F%E7%90%86)
|
||||
|
||||
进入PaddleOCR根目录,通过以下命令对模型进行敏感度分析训练:
|
||||
```bash
|
||||
python3.7 deploy/slim/prune/sensitivity_anal.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrain_weights="your trained model"
|
||||
```
|
||||
|
||||
### 4. 导出模型、预测部署
|
||||
|
||||
在得到裁剪训练保存的模型后,我们可以将其导出为inference_model:
|
||||
```bash
|
||||
pytho3.7 deploy/slim/prune/export_prune_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrain_weights=./output/det_db/best_accuracy Global.save_inference_dir=inference_model
|
||||
```
|
||||
|
||||
inference model的预测和部署参考:
|
||||
1. [inference model python端预测](../../../doc/doc_ch/inference.md)
|
||||
2. [inference model C++预测](../../cpp_infer/readme.md)
|
|
@ -0,0 +1,71 @@
|
|||
|
||||
## Introduction
|
||||
|
||||
Generally, a more complex model would achive better performance in the task, but it also leads to some redundancy in the model. Model Pruning is a technique that reduces this redundancy by removing the sub-models in the neural network model, so as to reduce model calculation complexity and improve model inference performance.
|
||||
|
||||
This example uses PaddleSlim provided[APIs of Pruning](https://paddlepaddle.github.io/PaddleSlim/api/prune_api/) to compress the OCR model.
|
||||
[PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim), an open source library which integrates model pruning, quantization (including quantization training and offline quantization), distillation, neural network architecture search, and many other commonly used and leading model compression technique in the industry.
|
||||
|
||||
It is recommended that you could understand following pages before reading this example:
|
||||
1. [PaddleOCR training methods](../../../doc/doc_ch/quickstart.md)
|
||||
2. [The demo of prune](https://github.com/PaddlePaddle/PaddleSlim/blob/release%2F2.0.0/docs/zh_cn/tutorials/pruning/dygraph/filter_pruning.md)
|
||||
|
||||
## Quick start
|
||||
|
||||
Five steps for OCR model prune:
|
||||
1. Install PaddleSlim
|
||||
2. Prepare the trained model
|
||||
3. Sensitivity analysis and tailoring training
|
||||
4. Export model, predict deployment
|
||||
|
||||
### 1. Install PaddleSlim
|
||||
|
||||
```bash
|
||||
git clone https://github.com/PaddlePaddle/PaddleSlim.git
|
||||
git checkout develop
|
||||
cd Paddleslim
|
||||
python3 setup.py install
|
||||
```
|
||||
|
||||
|
||||
### 2. Download Pretrain Model
|
||||
Model prune needs to load pre-trained models.
|
||||
PaddleOCR also provides a series of (models)[../../../doc/doc_en/models_list_en.md]. Developers can choose their own models or use their own models according to their needs.
|
||||
|
||||
|
||||
### 3. Pruning sensitivity analysis
|
||||
|
||||
After the pre-training model is loaded, sensitivity analysis is performed on each network layer of the model to understand the redundancy of each network layer, and save a sensitivity file which named: sen.pickle. After that, user could load the sensitivity file via the [methods provided by PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/sensitive.py#L221) and determining the pruning ratio of each network layer automatically. For specific details of sensitivity analysis, see:[Sensitivity analysis](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/tutorials/image_classification_sensitivity_analysis_tutorial.md)
|
||||
The data format of sensitivity file:
|
||||
sen.pickle(Dict){
|
||||
'layer_weight_name_0': sens_of_each_ratio(Dict){'pruning_ratio_0': acc_loss, 'pruning_ratio_1': acc_loss}
|
||||
'layer_weight_name_1': sens_of_each_ratio(Dict){'pruning_ratio_0': acc_loss, 'pruning_ratio_1': acc_loss}
|
||||
}
|
||||
|
||||
example:
|
||||
{
|
||||
'conv10_expand_weights': {0.1: 0.006509952684312718, 0.2: 0.01827734339798862, 0.3: 0.014528405644659832, 0.6: 0.06536008804270439, 0.8: 0.11798612250664964, 0.7: 0.12391408417493704, 0.4: 0.030615754498018757, 0.5: 0.047105205602406594}
|
||||
'conv10_linear_weights': {0.1: 0.05113190831455035, 0.2: 0.07705573833558801, 0.3: 0.12096721757739311, 0.6: 0.5135061352930738, 0.8: 0.7908166677143281, 0.7: 0.7272187676899062, 0.4: 0.1819252083008504, 0.5: 0.3728054727792405}
|
||||
}
|
||||
The function would return a dict after loading the sensitivity file. The keys of the dict are name of parameters in each layer. And the value of key is the information about pruning sensitivity of correspoding layer. In example, pruning 10% filter of the layer corresponding to conv10_expand_weights would lead to 0.65% degradation of model performance. The details could be seen at: [Sensitivity analysis](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/algo/algo.md#2-%E5%8D%B7%E7%A7%AF%E6%A0%B8%E5%89%AA%E8%A3%81%E5%8E%9F%E7%90%86)
|
||||
|
||||
|
||||
Enter the PaddleOCR root directory,perform sensitivity analysis on the model with the following command:
|
||||
|
||||
```bash
|
||||
|
||||
python3.7 deploy/slim/prune/sensitivity_anal.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrain_weights="your trained model"
|
||||
|
||||
```
|
||||
|
||||
|
||||
### 5. Export inference model and deploy it
|
||||
|
||||
We can export the pruned model as inference_model for deployment:
|
||||
```bash
|
||||
python deploy/slim/prune/export_prune_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o Global.pretrain_weights=./output/det_db/best_accuracy Global.test_batch_size_per_card=1 Global.save_inference_dir=inference_model
|
||||
```
|
||||
|
||||
Reference for prediction and deployment of inference model:
|
||||
1. [inference model python prediction](../../../doc/doc_en/inference_en.md)
|
||||
2. [inference model C++ prediction](../../cpp_infer/readme_en.md)
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
||||
|
||||
import paddle
|
||||
from ppocr.data import build_dataloader
|
||||
from ppocr.modeling.architectures import build_model
|
||||
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.metrics import build_metric
|
||||
from ppocr.utils.save_load import init_model
|
||||
import tools.program as program
|
||||
|
||||
|
||||
def main(config, device, logger, vdl_writer):
|
||||
|
||||
global_config = config['Global']
|
||||
|
||||
# build dataloader
|
||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||
|
||||
# build post process
|
||||
post_process_class = build_post_process(config['PostProcess'],
|
||||
global_config)
|
||||
|
||||
# build model
|
||||
# for rec algorithm
|
||||
if hasattr(post_process_class, 'character'):
|
||||
char_num = len(getattr(post_process_class, 'character'))
|
||||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
model = build_model(config['Architecture'])
|
||||
|
||||
flops = paddle.flops(model, [1, 3, 640, 640])
|
||||
logger.info(f"FLOPs before pruning: {flops}")
|
||||
|
||||
from paddleslim.dygraph import FPGMFilterPruner
|
||||
model.train()
|
||||
pruner = FPGMFilterPruner(model, [1, 3, 640, 640])
|
||||
|
||||
# build metric
|
||||
eval_class = build_metric(config['Metric'])
|
||||
|
||||
def eval_fn():
|
||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class)
|
||||
logger.info(f"metric['hmean']: {metric['hmean']}")
|
||||
return metric['hmean']
|
||||
|
||||
params_sensitive = pruner.sensitive(
|
||||
eval_func=eval_fn,
|
||||
sen_file="./sen.pickle",
|
||||
skip_vars=[
|
||||
"conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"
|
||||
])
|
||||
|
||||
logger.info(
|
||||
"The sensitivity analysis results of model parameters saved in sen.pickle"
|
||||
)
|
||||
# calculate pruned params's ratio
|
||||
params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
|
||||
for key in params_sensitive.keys():
|
||||
logger.info(f"{key}, {params_sensitive[key]}")
|
||||
|
||||
plan = pruner.prune_vars(params_sensitive, [0])
|
||||
|
||||
flops = paddle.flops(model, [1, 3, 640, 640])
|
||||
logger.info(f"FLOPs after pruning: {flops}")
|
||||
|
||||
# load pretrain model
|
||||
pre_best_model_dict = init_model(config, model, logger, None)
|
||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class)
|
||||
logger.info(f"metric['hmean']: {metric['hmean']}")
|
||||
|
||||
# start export model
|
||||
from paddle.jit import to_static
|
||||
|
||||
infer_shape = [3, -1, -1]
|
||||
if config['Architecture']['model_type'] == "rec":
|
||||
infer_shape = [3, 32, -1] # for rec model, H must be 32
|
||||
|
||||
if 'Transform' in config['Architecture'] and config['Architecture'][
|
||||
'Transform'] is not None and config['Architecture'][
|
||||
'Transform']['name'] == 'TPS':
|
||||
logger.info(
|
||||
'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
|
||||
)
|
||||
infer_shape[-1] = 100
|
||||
model = to_static(
|
||||
model,
|
||||
input_spec=[
|
||||
paddle.static.InputSpec(
|
||||
shape=[None] + infer_shape, dtype='float32')
|
||||
])
|
||||
|
||||
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
|
||||
paddle.jit.save(model, save_path)
|
||||
logger.info('inference model is saved to {}'.format(save_path))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config, device, logger, vdl_writer = program.preprocess(is_train=True)
|
||||
main(config, device, logger, vdl_writer)
|
|
@ -0,0 +1,146 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..'))
|
||||
sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
|
||||
|
||||
import paddle
|
||||
import paddle.distributed as dist
|
||||
from ppocr.data import build_dataloader
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.losses import build_loss
|
||||
from ppocr.optimizer import build_optimizer
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.metrics import build_metric
|
||||
from ppocr.utils.save_load import init_model
|
||||
import tools.program as program
|
||||
|
||||
dist.get_world_size()
|
||||
|
||||
|
||||
def get_pruned_params(parameters):
|
||||
params = []
|
||||
|
||||
for param in parameters:
|
||||
if len(
|
||||
param.shape
|
||||
) == 4 and 'depthwise' not in param.name and 'transpose' not in param.name and "conv2d_57" not in param.name and "conv2d_56" not in param.name:
|
||||
params.append(param.name)
|
||||
return params
|
||||
|
||||
|
||||
def main(config, device, logger, vdl_writer):
|
||||
# init dist environment
|
||||
if config['Global']['distributed']:
|
||||
dist.init_parallel_env()
|
||||
|
||||
global_config = config['Global']
|
||||
|
||||
# build dataloader
|
||||
train_dataloader = build_dataloader(config, 'Train', device, logger)
|
||||
if config['Eval']:
|
||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||
else:
|
||||
valid_dataloader = None
|
||||
|
||||
# build post process
|
||||
post_process_class = build_post_process(config['PostProcess'],
|
||||
global_config)
|
||||
|
||||
# build model
|
||||
# for rec algorithm
|
||||
if hasattr(post_process_class, 'character'):
|
||||
char_num = len(getattr(post_process_class, 'character'))
|
||||
config['Architecture']["Head"]['out_channels'] = char_num
|
||||
model = build_model(config['Architecture'])
|
||||
|
||||
flops = paddle.flops(model, [1, 3, 640, 640])
|
||||
logger.info(f"FLOPs before pruning: {flops}")
|
||||
|
||||
from paddleslim.dygraph import FPGMFilterPruner
|
||||
model.train()
|
||||
pruner = FPGMFilterPruner(model, [1, 3, 640, 640])
|
||||
|
||||
# build loss
|
||||
loss_class = build_loss(config['Loss'])
|
||||
|
||||
# build optim
|
||||
optimizer, lr_scheduler = build_optimizer(
|
||||
config['Optimizer'],
|
||||
epochs=config['Global']['epoch_num'],
|
||||
step_each_epoch=len(train_dataloader),
|
||||
parameters=model.parameters())
|
||||
|
||||
# build metric
|
||||
eval_class = build_metric(config['Metric'])
|
||||
# load pretrain model
|
||||
pre_best_model_dict = init_model(config, model, logger, optimizer)
|
||||
|
||||
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
|
||||
format(len(train_dataloader), len(valid_dataloader)))
|
||||
# build metric
|
||||
eval_class = build_metric(config['Metric'])
|
||||
|
||||
logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
|
||||
format(len(train_dataloader), len(valid_dataloader)))
|
||||
|
||||
def eval_fn():
|
||||
metric = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class)
|
||||
logger.info(f"metric['hmean']: {metric['hmean']}")
|
||||
return metric['hmean']
|
||||
|
||||
params_sensitive = pruner.sensitive(
|
||||
eval_func=eval_fn,
|
||||
sen_file="./sen.pickle",
|
||||
skip_vars=[
|
||||
"conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"
|
||||
])
|
||||
|
||||
logger.info(
|
||||
"The sensitivity analysis results of model parameters saved in sen.pickle"
|
||||
)
|
||||
# calculate pruned params's ratio
|
||||
params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
|
||||
for key in params_sensitive.keys():
|
||||
logger.info(f"{key}, {params_sensitive[key]}")
|
||||
|
||||
plan = pruner.prune_vars(params_sensitive, [0])
|
||||
for param in model.parameters():
|
||||
if ("weights" in param.name and "conv" in param.name) or (
|
||||
"w_0" in param.name and "conv2d" in param.name):
|
||||
logger.info(f"{param.name}: {param.shape}")
|
||||
|
||||
flops = paddle.flops(model, [1, 3, 640, 640])
|
||||
logger.info(f"FLOPs after pruning: {flops}")
|
||||
|
||||
# start train
|
||||
|
||||
program.train(config, train_dataloader, valid_dataloader, device, model,
|
||||
loss_class, optimizer, lr_scheduler, post_process_class,
|
||||
eval_class, pre_best_model_dict, logger, vdl_writer)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config, device, logger, vdl_writer = program.preprocess(is_train=True)
|
||||
main(config, device, logger, vdl_writer)
|
|
@ -28,7 +28,9 @@ PaddleOCR开源的文本检测算法列表:
|
|||
| --- | --- | --- | --- | --- | --- |
|
||||
|SAST|ResNet50_vd|89.63%|78.44%|83.66%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)|
|
||||
|
||||
**说明:** SAST模型训练额外加入了icdar2013、icdar2017、COCO-Text、ArT等公开数据集进行调优。PaddleOCR用到的经过整理格式的英文公开数据集下载:[百度云地址](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (提取码: 2bpi)
|
||||
**说明:** SAST模型训练额外加入了icdar2013、icdar2017、COCO-Text、ArT等公开数据集进行调优。PaddleOCR用到的经过整理格式的英文公开数据集下载:
|
||||
* [百度云地址](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (提取码: 2bpi)
|
||||
* [Google Drive下载地址](https://drive.google.com/drive/folders/1ll2-XEVyCQLpJjawLDiRlvo_i4BqHCJe?usp=sharing)
|
||||
|
||||
PaddleOCR文本检测算法的训练和使用请参考文档教程中[模型训练/评估中的文本检测部分](./detection.md)。
|
||||
|
||||
|
|
|
@ -1,4 +1,12 @@
|
|||
## 文字角度分类
|
||||
### 方法介绍
|
||||
文字角度分类主要用于图片非0度的场景下,在这种场景下需要对图片里检测到的文本行进行一个转正的操作。在PaddleOCR系统内,
|
||||
文字检测之后得到的文本行图片经过仿射变换之后送入识别模型,此时只需要对文字进行一个0和180度的角度分类,因此PaddleOCR内置的
|
||||
文字角度分类器**只支持了0和180度的分类**。如果想支持更多角度,可以自己修改算法进行支持。
|
||||
|
||||
0和180度数据样本例子:
|
||||
|
||||
![](../imgs_results/angle_class_example.jpg)
|
||||
|
||||
### 数据准备
|
||||
|
||||
|
@ -13,7 +21,7 @@ ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/cls/dataset
|
|||
请参考下文组织您的数据。
|
||||
- 训练集
|
||||
|
||||
首先请将训练图片放入同一个文件夹(train_images),并用一个txt文件(cls_gt_train.txt)记录图片路径和标签。
|
||||
首先建议将训练图片放入同一个文件夹,并用一个txt文件(cls_gt_train.txt)记录图片路径和标签。
|
||||
|
||||
**注意:** 默认请将图片路径和图片标签用 `\t` 分割,如用其他方式分割将造成训练报错
|
||||
|
||||
|
@ -21,8 +29,8 @@ ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/cls/dataset
|
|||
|
||||
```
|
||||
" 图像文件名 图像标注信息 "
|
||||
train/word_001.jpg 0
|
||||
train/word_002.jpg 180
|
||||
train/cls/train/word_001.jpg 0
|
||||
train/cls/train/word_002.jpg 180
|
||||
```
|
||||
|
||||
最终训练集应有如下文件结构:
|
||||
|
|
|
@ -2,16 +2,18 @@
|
|||
# 基于Python预测引擎推理
|
||||
|
||||
inference 模型(`paddle.jit.save`保存的模型)
|
||||
一般是模型训练完成后保存的固化模型,多用于预测部署。训练过程中保存的模型是checkpoints模型,保存的是模型的参数,多用于恢复训练等。
|
||||
与checkpoints模型相比,inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合与实际系统集成。
|
||||
一般是模型训练,把模型结构和模型参数保存在文件中的固化模型,多用于预测部署场景。
|
||||
训练过程中保存的模型是checkpoints模型,保存的只有模型的参数,多用于恢复训练等。
|
||||
与checkpoints模型相比,inference 模型会额外保存模型的结构信息,在预测部署、加速推理上性能优越,灵活方便,适合于实际系统集成。
|
||||
|
||||
接下来首先介绍如何将训练的模型转换成inference模型,然后将依次介绍文本检测、文本角度分类器、文本识别以及三者串联基于预测引擎推理。
|
||||
接下来首先介绍如何将训练的模型转换成inference模型,然后将依次介绍文本检测、文本角度分类器、文本识别以及三者串联在CPU、GPU上的预测方法。
|
||||
|
||||
|
||||
- [一、训练模型转inference模型](#训练模型转inference模型)
|
||||
- [检测模型转inference模型](#检测模型转inference模型)
|
||||
- [识别模型转inference模型](#识别模型转inference模型)
|
||||
- [方向分类模型转inference模型](#方向分类模型转inference模型)
|
||||
- [方向分类模型转inference模型](#方向分类模型转inference模型)
|
||||
- [端到端模型转inference模型](#端到端模型转inference模型)
|
||||
|
||||
- [二、文本检测模型推理](#文本检测模型推理)
|
||||
- [1. 超轻量中文检测模型推理](#超轻量中文检测模型推理)
|
||||
|
@ -26,10 +28,13 @@ inference 模型(`paddle.jit.save`保存的模型)
|
|||
- [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
|
||||
- [5. 多语言模型的推理](#多语言模型的推理)
|
||||
|
||||
- [四、方向分类模型推理](#方向识别模型推理)
|
||||
- [四、端到端模型推理](#端到端模型推理)
|
||||
- [1. PGNet端到端模型推理](#PGNet端到端模型推理)
|
||||
|
||||
- [五、方向分类模型推理](#方向识别模型推理)
|
||||
- [1. 方向分类模型推理](#方向分类模型推理)
|
||||
|
||||
- [五、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理)
|
||||
- [六、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理)
|
||||
- [1. 超轻量中文OCR模型推理](#超轻量中文OCR模型推理)
|
||||
- [2. 其他模型推理](#其他模型推理)
|
||||
|
||||
|
@ -117,6 +122,32 @@ python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.pretrained_mo
|
|||
├── inference.pdiparams.info # 分类inference模型的参数信息,可忽略
|
||||
└── inference.pdmodel # 分类inference模型的program文件
|
||||
```
|
||||
<a name="端到端模型转inference模型"></a>
|
||||
### 端到端模型转inference模型
|
||||
|
||||
下载端到端模型:
|
||||
```
|
||||
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar && tar xf ./ch_lite/ch_ppocr_mobile_v2.0_cls_train.tar -C ./ch_lite/
|
||||
```
|
||||
|
||||
端到端模型转inference模型与检测的方式相同,如下:
|
||||
```
|
||||
# -c 后面设置训练算法的yml配置文件
|
||||
# -o 配置可选参数
|
||||
# Global.pretrained_model 参数设置待转换的训练模型地址,不用添加文件后缀 .pdmodel,.pdopt或.pdparams。
|
||||
# Global.load_static_weights 参数需要设置为 False。
|
||||
# Global.save_inference_dir参数设置转换的模型将保存的地址。
|
||||
|
||||
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./ch_lite/ch_ppocr_mobile_v2.0_cls_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e/
|
||||
```
|
||||
|
||||
转换成功后,在目录下有三个文件:
|
||||
```
|
||||
/inference/e2e/
|
||||
├── inference.pdiparams # 分类inference模型的参数文件
|
||||
├── inference.pdiparams.info # 分类inference模型的参数信息,可忽略
|
||||
└── inference.pdmodel # 分类inference模型的program文件
|
||||
```
|
||||
|
||||
<a name="文本检测模型推理"></a>
|
||||
## 二、文本检测模型推理
|
||||
|
@ -140,7 +171,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_m
|
|||
![](../imgs_results/det_res_00018069.jpg)
|
||||
|
||||
通过参数`limit_type`和`det_limit_side_len`来对图片的尺寸进行限制,
|
||||
`litmit_type`可选参数为[`max`, `min`],
|
||||
`limit_type`可选参数为[`max`, `min`],
|
||||
`det_limit_size_len` 为正整数,一般设置为32 的倍数,比如960。
|
||||
|
||||
参数默认设置为`limit_type='max', det_limit_side_len=960`。表示网络输入图像的最长边不能超过960,
|
||||
|
@ -331,8 +362,38 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" -
|
|||
Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
|
||||
```
|
||||
|
||||
<a name="端到端模型推理"></a>
|
||||
## 四、端到端模型推理
|
||||
|
||||
端到端模型推理,默认使用PGNet模型的配置参数。当不使用PGNet模型时,在推理时,需要通过传入相应的参数进行算法适配,细节参考下文。
|
||||
<a name="PGNet端到端模型推理"></a>
|
||||
### 1. PGNet端到端模型推理
|
||||
#### (1). 四边形文本检测模型(ICDAR2015)
|
||||
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)),可以使用如下命令进行转换:
|
||||
```
|
||||
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/iter_epoch_450 Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
|
||||
```
|
||||
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`**,可以执行如下命令:
|
||||
```
|
||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
|
||||
```
|
||||
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
|
||||
|
||||
![](../imgs_results/e2e_res_img_10_pgnet.jpg)
|
||||
|
||||
#### (2). 弯曲文本检测模型(Total-Text)
|
||||
和四边形文本检测模型共用一个推理模型
|
||||
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`,**可以执行如下命令:
|
||||
```
|
||||
python3.7 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
|
||||
```
|
||||
可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
|
||||
|
||||
![](../imgs_results/e2e_res_img623_pgnet.jpg)
|
||||
|
||||
|
||||
<a name="方向分类模型推理"></a>
|
||||
## 四、方向分类模型推理
|
||||
## 五、方向分类模型推理
|
||||
|
||||
下面将介绍方向分类模型推理。
|
||||
|
||||
|
@ -357,7 +418,7 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999982]
|
|||
```
|
||||
|
||||
<a name="文本检测、方向分类和文字识别串联推理"></a>
|
||||
## 五、文本检测、方向分类和文字识别串联推理
|
||||
## 六、文本检测、方向分类和文字识别串联推理
|
||||
<a name="超轻量中文OCR模型推理"></a>
|
||||
### 1. 超轻量中文OCR模型推理
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ sudo nvidia-docker run --name ppocr -v $PWD:/paddle --shm-size=64G --network=hos
|
|||
sudo docker container exec -it ppocr /bin/bash
|
||||
```
|
||||
|
||||
**2. 安装PaddlePaddle Fluid v2.0**
|
||||
**2. 安装PaddlePaddle 2.0**
|
||||
```
|
||||
pip3 install --upgrade pip
|
||||
|
||||
|
|
|
@ -0,0 +1,176 @@
|
|||
# 端对端OCR算法-PGNet
|
||||
- [一、简介](#简介)
|
||||
- [二、环境配置](#环境配置)
|
||||
- [三、快速使用](#快速使用)
|
||||
- [四、快速训练](#开始训练)
|
||||
- [五、预测推理](#预测推理)
|
||||
|
||||
|
||||
<a name="简介"></a>
|
||||
##简介
|
||||
OCR算法可以分为两阶段算法和端对端的算法。二阶段OCR算法一般分为两个部分,文本检测和文本识别算法,文件检测算法从图像中得到文本行的检测框,然后识别算法去识别文本框中的内容。而端对端OCR算法可以在一个算法中完成文字检测和文字识别,其基本思想是设计一个同时具有检测单元和识别模块的模型,共享其中两者的CNN特征,并联合训练。由于一个算法即可完成文字识别,端对端模型更小,速度更快。
|
||||
|
||||
### PGNet算法介绍
|
||||
近些年来,端对端OCR算法得到了良好的发展,包括MaskTextSpotter系列、TextSnake、TextDragon、PGNet系列等算法。在这些算法中,PGNet算法具备其他算法不具备的优势,包括:
|
||||
- 设计PGNet loss指导训练,不需要字符级别的标注
|
||||
- 不需要NMS和ROI相关操作,加速预测
|
||||
- 提出预测文本行内的阅读顺序模块;
|
||||
- 提出基于图的修正模块(GRM)来进一步提高模型识别性能
|
||||
- 精度更高,预测速度更快
|
||||
|
||||
PGNet算法细节详见[论文](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf), 算法原理图如下所示:
|
||||
![](../pgnet_framework.png)
|
||||
输入图像经过特征提取送入四个分支,分别是:文本边缘偏移量预测TBO模块,文本中心线预测TCL模块,文本方向偏移量预测TDO模块,以及文本字符分类图预测TCC模块。
|
||||
其中TBO以及TCL的输出经过后处理后可以得到文本的检测结果,TCL、TDO、TCC负责文本识别。
|
||||
其检测识别效果图如下:
|
||||
![](../imgs_results/e2e_res_img293_pgnet.png)
|
||||
![](../imgs_results/e2e_res_img295_pgnet.png)
|
||||
|
||||
<a name="环境配置"></a>
|
||||
##环境配置
|
||||
请先参考[快速安装](./installation.md)配置PaddleOCR运行环境。
|
||||
|
||||
*注意:也可以通过 whl 包安装使用PaddleOCR,具体参考[Paddleocr Package使用说明](./whl.md)。*
|
||||
|
||||
<a name="快速使用"></a>
|
||||
##快速使用
|
||||
### inference模型下载
|
||||
本节以训练好的端到端模型为例,快速使用模型预测,首先下载训练好的端到端inference模型[下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.tar)
|
||||
```
|
||||
mkdir inference && cd inference
|
||||
# 下载英文端到端模型并解压
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/e2e_server_pgnetA_infer.tar && tar xf e2e_server_pgnetA_infer.tar
|
||||
```
|
||||
* windows 环境下如果没有安装wget,下载模型时可将链接复制到浏览器中下载,并解压放置在相应目录下
|
||||
|
||||
解压完毕后应有如下文件结构:
|
||||
```
|
||||
├── e2e_server_pgnetA_infer
|
||||
│ ├── inference.pdiparams
|
||||
│ ├── inference.pdiparams.info
|
||||
│ └── inference.pdmodel
|
||||
```
|
||||
### 单张图像或者图像集合预测
|
||||
```bash
|
||||
# 预测image_dir指定的单张图像
|
||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
|
||||
|
||||
# 预测image_dir指定的图像集合
|
||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
|
||||
|
||||
# 如果想使用CPU进行预测,需设置use_gpu参数为False
|
||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True --use_gpu=False
|
||||
```
|
||||
<a name="开始训练"></a>
|
||||
##开始训练
|
||||
本节以totaltext数据集为例,介绍PaddleOCR中端到端模型的训练、评估与测试。
|
||||
###数据形式为icdar, 十六点标注数据
|
||||
解压数据集和下载标注文件后,PaddleOCR/train_data/total_text/train/ 有两个文件夹,分别是:
|
||||
```
|
||||
/PaddleOCR/train_data/total_text/train/
|
||||
|- rgb/ total_text数据集的训练数据
|
||||
|- gt_0.png
|
||||
| ...
|
||||
|- total_text.txt total_text数据集的训练标注
|
||||
```
|
||||
|
||||
提供的标注文件格式如下,中间用"\t"分隔:
|
||||
```
|
||||
" 图像文件名 json.dumps编码的图像标注信息"
|
||||
rgb/gt_0.png [{"transcription": "EST", "points": [[1004.0,689.0],[1019.0,698.0],[1034.0,708.0],[1049.0,718.0],[1064.0,728.0],[1079.0,738.0],[1095.0,748.0],[1094.0,774.0],[1079.0,765.0],[1065.0,756.0],[1050.0,747.0],[1036.0,738.0],[1021.0,729.0],[1007.0,721.0]]}, {...}]
|
||||
```
|
||||
json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 `points` 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。
|
||||
`transcription` 表示当前文本框的文字,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。**
|
||||
如果您想在其他数据集上训练,可以按照上述形式构建标注文件。
|
||||
|
||||
### 快速启动训练
|
||||
|
||||
模型训练一般分两步骤进行,第一步可以选择用合成数据训练,第二步加载第一步训练好的模型训练,这边我们提供了第一步训练好的模型,可以直接加载,从第二步开始训练
|
||||
[下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/train_step1.tar)
|
||||
```shell
|
||||
cd PaddleOCR/
|
||||
下载ResNet50_vd的动态图预训练模型
|
||||
wget -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/train_step1.tar
|
||||
可以得到以下的文件格式
|
||||
./pretrain_models/train_step1/
|
||||
└─ best_accuracy.pdopt
|
||||
└─ best_accuracy.states
|
||||
└─ best_accuracy.pdparams
|
||||
|
||||
```
|
||||
|
||||
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
|
||||
|
||||
```shell
|
||||
# 单机单卡训练 e2e 模型
|
||||
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./pretrain_models/train_step1/best_accuracy Global.load_static_weights=False
|
||||
# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID
|
||||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./pretrain_models/train_step1/best_accuracy Global.load_static_weights=False
|
||||
```
|
||||
|
||||
上述指令中,通过-c 选择训练使用configs/e2e/e2e_r50_vd_pg.yml配置文件。
|
||||
有关配置文件的详细解释,请参考[链接](./config.md)。
|
||||
|
||||
您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001
|
||||
```shell
|
||||
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Optimizer.base_lr=0.0001
|
||||
```
|
||||
|
||||
#### 断点训练
|
||||
|
||||
如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径:
|
||||
```shell
|
||||
python3 tools/train.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints=./your/trained/model
|
||||
```
|
||||
|
||||
**注意**:`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。
|
||||
|
||||
<a name="预测推理"></a>
|
||||
## 预测推理
|
||||
|
||||
PaddleOCR计算三个OCR端到端相关的指标,分别是:Precision、Recall、Hmean。
|
||||
|
||||
运行如下代码,根据配置文件`e2e_r50_vd_pg.yml`中`save_res_path`指定的测试集检测结果文件,计算评估指标。
|
||||
|
||||
评估时设置后处理参数`max_side_len=768`,使用不同数据集、不同模型训练,可调整参数进行优化
|
||||
训练中模型参数默认保存在`Global.save_model_dir`目录下。在评估指标时,需要设置`Global.checkpoints`指向保存的参数文件。
|
||||
```shell
|
||||
python3 tools/eval.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.checkpoints="{path/to/weights}/best_accuracy"
|
||||
```
|
||||
|
||||
### 测试端到端效果
|
||||
测试单张图像的端到端识别效果
|
||||
```shell
|
||||
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
|
||||
```
|
||||
|
||||
测试文件夹下所有图像的端到端识别效果
|
||||
```shell
|
||||
python3 tools/infer_e2e.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy" Global.load_static_weights=false
|
||||
```
|
||||
|
||||
###转为推理模型
|
||||
### (1). 四边形文本检测模型(ICDAR2015)
|
||||
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,以英文数据集训练的模型为例[模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar) ,可以使用如下命令进行转换:
|
||||
```
|
||||
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
|
||||
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/iter_epoch_450 Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
|
||||
```
|
||||
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`**,可以执行如下命令:
|
||||
```
|
||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
|
||||
```
|
||||
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
|
||||
|
||||
![](../imgs_results/e2e_res_img_10_pgnet.jpg)
|
||||
|
||||
### (2). 弯曲文本检测模型(Total-Text)
|
||||
对于弯曲文本样例
|
||||
|
||||
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`,**可以执行如下命令:
|
||||
```
|
||||
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
|
||||
```
|
||||
可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
|
||||
|
||||
![](../imgs_results/e2e_res_img623_pgnet.jpg)
|
|
@ -1,61 +1,94 @@
|
|||
## 文字识别
|
||||
|
||||
|
||||
- [一、数据准备](#数据准备)
|
||||
- [数据下载](#数据下载)
|
||||
- [自定义数据集](#自定义数据集)
|
||||
- [字典](#字典)
|
||||
- [支持空格](#支持空格)
|
||||
- [1 数据准备](#数据准备)
|
||||
- [1.1 自定义数据集](#自定义数据集)
|
||||
- [1.2 数据下载](#数据下载)
|
||||
- [1.3 字典](#字典)
|
||||
- [1.4 支持空格](#支持空格)
|
||||
|
||||
- [二、启动训练](#启动训练)
|
||||
- [1. 数据增强](#数据增强)
|
||||
- [2. 训练](#训练)
|
||||
- [3. 小语种](#小语种)
|
||||
- [2 启动训练](#启动训练)
|
||||
- [2.1 数据增强](#数据增强)
|
||||
- [2.2 训练](#训练)
|
||||
- [2.3 小语种](#小语种)
|
||||
|
||||
- [三、评估](#评估)
|
||||
- [3 评估](#评估)
|
||||
|
||||
- [四、预测](#预测)
|
||||
- [1. 训练引擎预测](#训练引擎预测)
|
||||
- [4 预测](#预测)
|
||||
- [4.1 训练引擎预测](#训练引擎预测)
|
||||
|
||||
|
||||
<a name="数据准备"></a>
|
||||
### 数据准备
|
||||
### 1. 数据准备
|
||||
|
||||
|
||||
PaddleOCR 支持两种数据格式: `lmdb` 用于训练公开数据,调试算法; `通用数据` 训练自己的数据:
|
||||
|
||||
请按如下步骤设置数据集:
|
||||
PaddleOCR 支持两种数据格式:
|
||||
- `lmdb` 用于训练以lmdb格式存储的数据集;
|
||||
- `通用数据` 用于训练以文本文件存储的数据集:
|
||||
|
||||
训练数据的默认存储路径是 `PaddleOCR/train_data`,如果您的磁盘上已有数据集,只需创建软链接至数据集目录:
|
||||
|
||||
```
|
||||
# linux and mac os
|
||||
ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/dataset
|
||||
# windows
|
||||
mklink /d <path/to/paddle_ocr>/train_data/dataset <path/to/dataset>
|
||||
```
|
||||
|
||||
<a name="数据下载"></a>
|
||||
* 数据下载
|
||||
<a name="准备数据集"></a>
|
||||
#### 1.1 自定义数据集
|
||||
下面以通用数据集为例, 介绍如何准备数据集:
|
||||
|
||||
若您本地没有数据集,可以在官网下载 [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) 数据,用于快速验证。也可以参考[DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here),下载 benchmark 所需的lmdb格式数据集。
|
||||
如果希望复现SRN的论文指标,需要下载离线[增广数据](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA),提取码: y3ry。增广数据是由MJSynth和SynthText做旋转和扰动得到的。数据下载完成后请解压到 {your_path}/PaddleOCR/train_data/data_lmdb_release/training/ 路径下。
|
||||
* 训练集
|
||||
|
||||
<a name="自定义数据集"></a>
|
||||
* 使用自己数据集
|
||||
建议将训练图片放入同一个文件夹,并用一个txt文件(rec_gt_train.txt)记录图片路径和标签,txt文件里的内容如下:
|
||||
|
||||
若您希望使用自己的数据进行训练,请参考下文组织您的数据。
|
||||
|
||||
- 训练集
|
||||
|
||||
首先请将训练图片放入同一个文件夹(train_images),并用一个txt文件(rec_gt_train.txt)记录图片路径和标签。
|
||||
|
||||
**注意:** 默认请将图片路径和图片标签用 \t 分割,如用其他方式分割将造成训练报错
|
||||
**注意:** txt文件中默认请将图片路径和图片标签用 \t 分割,如用其他方式分割将造成训练报错。
|
||||
|
||||
```
|
||||
" 图像文件名 图像标注信息 "
|
||||
|
||||
train_data/train_0001.jpg 简单可依赖
|
||||
train_data/train_0002.jpg 用科技让复杂的世界更简单
|
||||
train_data/rec/train/word_001.jpg 简单可依赖
|
||||
train_data/rec/train/word_002.jpg 用科技让复杂的世界更简单
|
||||
...
|
||||
```
|
||||
PaddleOCR 提供了一份用于训练 icdar2015 数据集的标签文件,通过以下方式下载:
|
||||
|
||||
最终训练集应有如下文件结构:
|
||||
```
|
||||
|-train_data
|
||||
|-rec
|
||||
|- rec_gt_train.txt
|
||||
|- train
|
||||
|- word_001.png
|
||||
|- word_002.jpg
|
||||
|- word_003.jpg
|
||||
| ...
|
||||
```
|
||||
|
||||
- 测试集
|
||||
|
||||
同训练集类似,测试集也需要提供一个包含所有图片的文件夹(test)和一个rec_gt_test.txt,测试集的结构如下所示:
|
||||
|
||||
```
|
||||
|-train_data
|
||||
|-rec
|
||||
|- rec_gt_test.txt
|
||||
|- test
|
||||
|- word_001.jpg
|
||||
|- word_002.jpg
|
||||
|- word_003.jpg
|
||||
| ...
|
||||
```
|
||||
|
||||
<a name="数据下载"></a>
|
||||
|
||||
1.2 数据下载
|
||||
|
||||
若您本地没有数据集,可以在官网下载 [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) 数据,用于快速验证。也可以参考[DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here) ,下载 benchmark 所需的lmdb格式数据集。
|
||||
|
||||
如果你使用的是icdar2015的公开数据集,PaddleOCR 提供了一份用于训练 icdar2015 数据集的标签文件,通过以下方式下载:
|
||||
|
||||
如果希望复现SRN的论文指标,需要下载离线[增广数据](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA),提取码: y3ry。增广数据是由MJSynth和SynthText做旋转和扰动得到的。数据下载完成后请解压到 {your_path}/PaddleOCR/train_data/data_lmdb_release/training/ 路径下。
|
||||
|
||||
```
|
||||
# 训练集标签
|
||||
|
@ -71,34 +104,8 @@ PaddleOCR 也提供了数据格式转换脚本,可以将官网 label 转换支
|
|||
python gen_label.py --mode="rec" --input_path="{path/of/origin/label}" --output_label="rec_gt_label.txt"
|
||||
```
|
||||
|
||||
最终训练集应有如下文件结构:
|
||||
```
|
||||
|-train_data
|
||||
|-ic15_data
|
||||
|- rec_gt_train.txt
|
||||
|- train
|
||||
|- word_001.png
|
||||
|- word_002.jpg
|
||||
|- word_003.jpg
|
||||
| ...
|
||||
```
|
||||
|
||||
- 测试集
|
||||
|
||||
同训练集类似,测试集也需要提供一个包含所有图片的文件夹(test)和一个rec_gt_test.txt,测试集的结构如下所示:
|
||||
|
||||
```
|
||||
|-train_data
|
||||
|-ic15_data
|
||||
|- rec_gt_test.txt
|
||||
|- test
|
||||
|- word_001.jpg
|
||||
|- word_002.jpg
|
||||
|- word_003.jpg
|
||||
| ...
|
||||
```
|
||||
<a name="字典"></a>
|
||||
- 字典
|
||||
1.3 字典
|
||||
|
||||
最后需要提供一个字典({word_dict_name}.txt),使模型在训练时,可以将所有出现的字符映射为字典的索引。
|
||||
|
||||
|
@ -115,6 +122,10 @@ n
|
|||
|
||||
word_dict.txt 每行有一个单字,将字符与数字索引映射在一起,“and” 将被映射成 [2 5 1]
|
||||
|
||||
* 内置字典
|
||||
|
||||
PaddleOCR内置了一部分字典,可以按需使用。
|
||||
|
||||
`ppocr/utils/ppocr_keys_v1.txt` 是一个包含6623个字符的中文字典
|
||||
|
||||
`ppocr/utils/ic15_dict.txt` 是一个包含36个字符的英文字典
|
||||
|
@ -130,7 +141,7 @@ word_dict.txt 每行有一个单字,将字符与数字索引映射在一起,
|
|||
`ppocr/utils/dict/en_dict.txt` 是一个包含63个字符的英文字典
|
||||
|
||||
|
||||
您可以按需使用。
|
||||
|
||||
|
||||
目前的多语言模型仍处在demo阶段,会持续优化模型并补充语种,**非常欢迎您为我们提供其他语言的字典和字体**,
|
||||
如您愿意可将字典文件提交至 [dict](../../ppocr/utils/dict),我们会在Repo中感谢您。
|
||||
|
@ -141,13 +152,13 @@ word_dict.txt 每行有一个单字,将字符与数字索引映射在一起,
|
|||
并将 `character_type` 设置为 `ch`。
|
||||
|
||||
<a name="支持空格"></a>
|
||||
- 添加空格类别
|
||||
1.4 添加空格类别
|
||||
|
||||
如果希望支持识别"空格"类别, 请将yml文件中的 `use_space_char` 字段设置为 `True`。
|
||||
|
||||
|
||||
<a name="启动训练"></a>
|
||||
### 启动训练
|
||||
### 2. 启动训练
|
||||
|
||||
PaddleOCR提供了训练脚本、评估脚本和预测脚本,本节将以 CRNN 识别模型为例:
|
||||
|
||||
|
@ -172,7 +183,7 @@ tar -xf rec_mv3_none_bilstm_ctc_v2.0_train.tar && rm -rf rec_mv3_none_bilstm_ctc
|
|||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_icdar15_train.yml
|
||||
```
|
||||
<a name="数据增强"></a>
|
||||
- 数据增强
|
||||
#### 2.1 数据增强
|
||||
|
||||
PaddleOCR提供了多种数据增强方式,如果您希望在训练时加入扰动,请在配置文件中设置 `distort: true`。
|
||||
|
||||
|
@ -183,7 +194,7 @@ PaddleOCR提供了多种数据增强方式,如果您希望在训练时加入
|
|||
*由于OpenCV的兼容性问题,扰动操作暂时只支持Linux*
|
||||
|
||||
<a name="训练"></a>
|
||||
- 训练
|
||||
#### 2.2 训练
|
||||
|
||||
PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_train.yml` 中修改 `eval_batch_step` 设置评估频率,默认每500个iter评估一次。评估过程中默认将最佳acc模型,保存为 `output/rec_CRNN/best_accuracy` 。
|
||||
|
||||
|
@ -272,7 +283,7 @@ Eval:
|
|||
**注意,预测/评估时的配置文件请务必与训练一致。**
|
||||
|
||||
<a name="小语种"></a>
|
||||
- 小语种
|
||||
#### 2.3 小语种
|
||||
|
||||
PaddleOCR目前已支持26种(除中文外)语种识别,`configs/rec/multi_languages` 路径下提供了一个多语言的配置文件模版: [rec_multi_language_lite_train.yml](../../configs/rec/multi_language/rec_multi_language_lite_train.yml)。
|
||||
|
||||
|
@ -415,7 +426,7 @@ Eval:
|
|||
...
|
||||
```
|
||||
<a name="评估"></a>
|
||||
### 评估
|
||||
### 3 评估
|
||||
|
||||
评估数据集可以通过 `configs/rec/rec_icdar15_train.yml` 修改Eval中的 `label_file_path` 设置。
|
||||
|
||||
|
@ -425,10 +436,10 @@ python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec
|
|||
```
|
||||
|
||||
<a name="预测"></a>
|
||||
### 预测
|
||||
### 4 预测
|
||||
|
||||
<a name="训练引擎预测"></a>
|
||||
* 训练引擎的预测
|
||||
#### 4.1 训练引擎的预测
|
||||
|
||||
使用 PaddleOCR 训练好的模型,可以通过以下脚本进行快速预测。
|
||||
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
# paddleocr package使用说明
|
||||
|
||||
## 快速上手
|
||||
## 1 快速上手
|
||||
|
||||
### 安装whl包
|
||||
### 1.1 安装whl包
|
||||
|
||||
pip安装
|
||||
```bash
|
||||
|
@ -14,9 +14,12 @@ pip install "paddleocr>=2.0.1" # 推荐使用2.0.1+版本
|
|||
python3 setup.py bdist_wheel
|
||||
pip3 install dist/paddleocr-x.x.x-py3-none-any.whl # x.x.x是paddleocr的版本号
|
||||
```
|
||||
### 1. 代码使用
|
||||
|
||||
* 检测+分类+识别全流程
|
||||
## 2 使用
|
||||
### 2.1 代码使用
|
||||
paddleocr whl包会自动下载ppocr轻量级模型作为默认模型,可以根据第3节**自定义模型**进行自定义更换。
|
||||
|
||||
* 检测+方向分类器+识别全流程
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换
|
||||
|
@ -33,7 +36,7 @@ image = Image.open(img_path).convert('RGB')
|
|||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
@ -66,7 +69,7 @@ image = Image.open(img_path).convert('RGB')
|
|||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
@ -84,7 +87,7 @@ im_show.save('result.jpg')
|
|||
</div>
|
||||
|
||||
|
||||
* 分类+识别
|
||||
* 方向分类器+识别
|
||||
```python
|
||||
from paddleocr import PaddleOCR
|
||||
ocr = PaddleOCR(use_angle_cls=True) # need to run only once to download and load model into memory
|
||||
|
@ -111,7 +114,7 @@ for line in result:
|
|||
from PIL import Image
|
||||
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||
im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
@ -143,7 +146,7 @@ for line in result:
|
|||
['韩国小馆', 0.9907421]
|
||||
```
|
||||
|
||||
* 单独执行分类
|
||||
* 单独执行方向分类器
|
||||
```python
|
||||
from paddleocr import PaddleOCR
|
||||
ocr = PaddleOCR(use_angle_cls=True) # need to run only once to download and load model into memory
|
||||
|
@ -157,14 +160,14 @@ for line in result:
|
|||
['0', 0.9999924]
|
||||
```
|
||||
|
||||
### 通过命令行使用
|
||||
### 2.2 通过命令行使用
|
||||
|
||||
查看帮助信息
|
||||
```bash
|
||||
paddleocr -h
|
||||
```
|
||||
|
||||
* 检测+分类+识别全流程
|
||||
* 检测+方向分类器+识别全流程
|
||||
```bash
|
||||
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --use_angle_cls true
|
||||
```
|
||||
|
@ -188,7 +191,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg
|
|||
......
|
||||
```
|
||||
|
||||
* 分类+识别
|
||||
* 方向分类器+识别
|
||||
```bash
|
||||
paddleocr --image_dir PaddleOCR/doc/imgs_words/ch/word_1.jpg --use_angle_cls true --det false
|
||||
```
|
||||
|
@ -220,7 +223,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs_words/ch/word_1.jpg --det false
|
|||
['韩国小馆', 0.9907421]
|
||||
```
|
||||
|
||||
* 单独执行分类
|
||||
* 单独执行方向分类器
|
||||
```bash
|
||||
paddleocr --image_dir PaddleOCR/doc/imgs_words/ch/word_1.jpg --use_angle_cls true --det false --rec false
|
||||
```
|
||||
|
@ -230,11 +233,11 @@ paddleocr --image_dir PaddleOCR/doc/imgs_words/ch/word_1.jpg --use_angle_cls tru
|
|||
['0', 0.9999924]
|
||||
```
|
||||
|
||||
## 自定义模型
|
||||
## 3 自定义模型
|
||||
当内置模型无法满足需求时,需要使用到自己训练的模型。
|
||||
首先,参照[inference.md](./inference.md) 第一节转换将检测、分类和识别模型转换为inference模型,然后按照如下方式使用
|
||||
|
||||
### 代码使用
|
||||
### 3.1 代码使用
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
# 模型路径下必须含有model和params文件
|
||||
|
@ -250,22 +253,22 @@ image = Image.open(img_path).convert('RGB')
|
|||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
### 通过命令行使用
|
||||
### 3.2 通过命令行使用
|
||||
|
||||
```bash
|
||||
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true
|
||||
```
|
||||
|
||||
### 使用网络图片或者numpy数组作为输入
|
||||
## 4 使用网络图片或者numpy数组作为输入
|
||||
|
||||
1. 网络图片
|
||||
### 4.1 网络图片
|
||||
|
||||
代码使用
|
||||
- 代码使用
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换
|
||||
|
@ -282,16 +285,16 @@ image = Image.open(img_path).convert('RGB')
|
|||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
命令行模式
|
||||
- 命令行模式
|
||||
```bash
|
||||
paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg --use_angle_cls=true
|
||||
```
|
||||
|
||||
2. numpy数组
|
||||
### 4.2 numpy数组
|
||||
仅通过代码使用时支持numpy数组作为输入
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
|
@ -301,7 +304,7 @@ ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to downlo
|
|||
img_path = 'PaddleOCR/doc/imgs/11.jpg'
|
||||
img = cv2.imread(img_path)
|
||||
# img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY), 如果你自己训练的模型支持灰度图,可以将这句话的注释取消
|
||||
result = ocr.ocr(img_path, cls=True)
|
||||
result = ocr.ocr(img, cls=True)
|
||||
for line in result:
|
||||
print(line)
|
||||
|
||||
|
@ -311,12 +314,12 @@ image = Image.open(img_path).convert('RGB')
|
|||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
## 参数说明
|
||||
## 5 参数说明
|
||||
|
||||
| 字段 | 说明 | 默认值 |
|
||||
|-------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------|
|
||||
|
|
|
@ -31,7 +31,9 @@ On Total-Text dataset, the text detection result is as follows:
|
|||
| --- | --- | --- | --- | --- | --- |
|
||||
|SAST|ResNet50_vd|89.63%|78.44%|83.66%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)|
|
||||
|
||||
**Note:** Additional data, like icdar2013, icdar2017, COCO-Text, ArT, was added to the model training of SAST. Download English public dataset in organized format used by PaddleOCR from [Baidu Drive](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (download code: 2bpi).
|
||||
**Note:** Additional data, like icdar2013, icdar2017, COCO-Text, ArT, was added to the model training of SAST. Download English public dataset in organized format used by PaddleOCR from:
|
||||
* [Baidu Drive](https://pan.baidu.com/s/12cPnZcVuV1zn5DOd4mqjVw) (download code: 2bpi).
|
||||
* [Google Drive](https://drive.google.com/drive/folders/1ll2-XEVyCQLpJjawLDiRlvo_i4BqHCJe?usp=sharing)
|
||||
|
||||
For the training guide and use of PaddleOCR text detection algorithms, please refer to the document [Text detection model training/evaluation/prediction](./detection_en.md)
|
||||
|
||||
|
|
|
@ -1,5 +1,12 @@
|
|||
## TEXT ANGLE CLASSIFICATION
|
||||
|
||||
### Method introduction
|
||||
The angle classification is used in the scene where the image is not 0 degrees. In this scene, it is necessary to perform a correction operation on the text line detected in the picture. In the PaddleOCR system,
|
||||
The text line image obtained after text detection is sent to the recognition model after affine transformation. At this time, only a 0 and 180 degree angle classification of the text is required, so the built-in PaddleOCR text angle classifier **only supports 0 and 180 degree classification**. If you want to support more angles, you can modify the algorithm yourself to support.
|
||||
|
||||
Example of 0 and 180 degree data samples:
|
||||
|
||||
![](../imgs_results/angle_class_example.jpg)
|
||||
### DATA PREPARATION
|
||||
|
||||
Please organize the dataset as follows:
|
||||
|
|
|
@ -5,7 +5,8 @@ The inference model (the model saved by `paddle.jit.save`) is generally a solidi
|
|||
|
||||
The model saved during the training process is the checkpoints model, which saves the parameters of the model and is mostly used to resume training.
|
||||
|
||||
Compared with the checkpoints model, the inference model will additionally save the structural information of the model. It has superior performance in predicting in deployment and accelerating inferencing, is flexible and convenient, and is suitable for integration with actual systems. For more details, please refer to the document [Classification Framework](https://github.com/PaddlePaddle/PaddleClas/blob/master/docs/zh_CN/extension/paddle_inference.md).
|
||||
Compared with the checkpoints model, the inference model will additionally save the structural information of the model. Therefore, it is easier to deploy because the model structure and model parameters are already solidified in the inference model file, and is suitable for integration with actual systems.
|
||||
For more details, please refer to the document [Classification Framework](https://github.com/PaddlePaddle/PaddleClas/blob/release%2F2.0/docs/zh_CN/extension/paddle_mobile_inference.md).
|
||||
|
||||
Next, we first introduce how to convert a trained model into an inference model, and then we will introduce text detection, text recognition, angle class, and the concatenation of them based on inference model.
|
||||
|
||||
|
@ -147,7 +148,7 @@ The visual text detection results are saved to the ./inference_results folder by
|
|||
![](../imgs_results/det_res_00018069.jpg)
|
||||
|
||||
You can use the parameters `limit_type` and `det_limit_side_len` to limit the size of the input image,
|
||||
The optional parameters of `litmit_type` are [`max`, `min`], and
|
||||
The optional parameters of `limit_type` are [`max`, `min`], and
|
||||
`det_limit_size_len` is a positive integer, generally set to a multiple of 32, such as 960.
|
||||
|
||||
The default setting of the parameters is `limit_type='max', det_limit_side_len=960`. Indicates that the longest side of the network input image cannot exceed 960,
|
||||
|
|
|
@ -33,7 +33,7 @@ You can also visit [DockerHub](https://hub.docker.com/r/paddlepaddle/paddle/tags
|
|||
sudo docker container exec -it ppocr /bin/bash
|
||||
```
|
||||
|
||||
**2. Install PaddlePaddle Fluid v2.0**
|
||||
**2. Install PaddlePaddle 2.0**
|
||||
```
|
||||
pip3 install --upgrade pip
|
||||
|
||||
|
|
|
@ -1,59 +1,95 @@
|
|||
## TEXT RECOGNITION
|
||||
|
||||
- [DATA PREPARATION](#DATA_PREPARATION)
|
||||
- [Dataset Download](#Dataset_download)
|
||||
- [Costom Dataset](#Costom_Dataset)
|
||||
- [Dictionary](#Dictionary)
|
||||
- [Add Space Category](#Add_space_category)
|
||||
- [1 DATA PREPARATION](#DATA_PREPARATION)
|
||||
- [1.1 Costom Dataset](#Costom_Dataset)
|
||||
- [1.2 Dataset Download](#Dataset_download)
|
||||
- [1.3 Dictionary](#Dictionary)
|
||||
- [1.4 Add Space Category](#Add_space_category)
|
||||
|
||||
- [TRAINING](#TRAINING)
|
||||
- [Data Augmentation](#Data_Augmentation)
|
||||
- [Training](#Training)
|
||||
- [Multi-language](#Multi_language)
|
||||
- [2 TRAINING](#TRAINING)
|
||||
- [2.1 Data Augmentation](#Data_Augmentation)
|
||||
- [2.2 Training](#Training)
|
||||
- [2.3 Multi-language](#Multi_language)
|
||||
|
||||
- [EVALUATION](#EVALUATION)
|
||||
- [3 EVALUATION](#EVALUATION)
|
||||
|
||||
- [PREDICTION](#PREDICTION)
|
||||
- [Training engine prediction](#Training_engine_prediction)
|
||||
- [4 PREDICTION](#PREDICTION)
|
||||
- [4.1 Training engine prediction](#Training_engine_prediction)
|
||||
|
||||
<a name="DATA_PREPARATION"></a>
|
||||
### DATA PREPARATION
|
||||
|
||||
|
||||
PaddleOCR supports two data formats: `LMDB` is used to train public data and evaluation algorithms; `general data` is used to train your own data:
|
||||
PaddleOCR supports two data formats:
|
||||
- `LMDB` is used to train data sets stored in lmdb format;
|
||||
- `general data` is used to train data sets stored in text files:
|
||||
|
||||
Please organize the dataset as follows:
|
||||
|
||||
The default storage path for training data is `PaddleOCR/train_data`, if you already have a dataset on your disk, just create a soft link to the dataset directory:
|
||||
|
||||
```
|
||||
# linux and mac os
|
||||
ln -sf <path/to/dataset> <path/to/paddle_ocr>/train_data/dataset
|
||||
# windows
|
||||
mklink /d <path/to/paddle_ocr>/train_data/dataset <path/to/dataset>
|
||||
```
|
||||
|
||||
<a name="Dataset_download"></a>
|
||||
* Dataset download
|
||||
|
||||
If you do not have a dataset locally, you can download it on the official website [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads). Also refer to [DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here),download the lmdb format dataset required for benchmark
|
||||
|
||||
If you want to reproduce the paper indicators of SRN, you need to download offline [augmented data](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA), extraction code: y3ry. The augmented data is obtained by rotation and perturbation of mjsynth and synthtext. Please unzip the data to {your_path}/PaddleOCR/train_data/data_lmdb_Release/training/path.
|
||||
|
||||
<a name="Costom_Dataset"></a>
|
||||
* Use your own dataset:
|
||||
#### 1.1 Costom dataset
|
||||
|
||||
If you want to use your own data for training, please refer to the following to organize your data.
|
||||
|
||||
- Training set
|
||||
|
||||
First put the training images in the same folder (train_images), and use a txt file (rec_gt_train.txt) to store the image path and label.
|
||||
It is recommended to put the training images in the same folder, and use a txt file (rec_gt_train.txt) to store the image path and label. The contents of the txt file are as follows:
|
||||
|
||||
* Note: by default, the image path and image label are split with \t, if you use other methods to split, it will cause training error
|
||||
|
||||
```
|
||||
" Image file name Image annotation "
|
||||
|
||||
train_data/train_0001.jpg 简单可依赖
|
||||
train_data/train_0002.jpg 用科技让复杂的世界更简单
|
||||
train_data/rec/train/word_001.jpg 简单可依赖
|
||||
train_data/rec/train/word_002.jpg 用科技让复杂的世界更简单
|
||||
...
|
||||
```
|
||||
|
||||
The final training set should have the following file structure:
|
||||
|
||||
```
|
||||
|-train_data
|
||||
|-rec
|
||||
|- rec_gt_train.txt
|
||||
|- train
|
||||
|- word_001.png
|
||||
|- word_002.jpg
|
||||
|- word_003.jpg
|
||||
| ...
|
||||
```
|
||||
|
||||
- Test set
|
||||
|
||||
Similar to the training set, the test set also needs to be provided a folder containing all images (test) and a rec_gt_test.txt. The structure of the test set is as follows:
|
||||
|
||||
```
|
||||
|-train_data
|
||||
|-rec
|
||||
|-ic15_data
|
||||
|- rec_gt_test.txt
|
||||
|- test
|
||||
|- word_001.jpg
|
||||
|- word_002.jpg
|
||||
|- word_003.jpg
|
||||
| ...
|
||||
```
|
||||
|
||||
<a name="Dataset_download"></a>
|
||||
#### 1.2 Dataset download
|
||||
|
||||
If you do not have a dataset locally, you can download it on the official website [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads). Also refer to [DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here) ,download the lmdb format dataset required for benchmark
|
||||
|
||||
If you want to reproduce the paper indicators of SRN, you need to download offline [augmented data](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA), extraction code: y3ry. The augmented data is obtained by rotation and perturbation of mjsynth and synthtext. Please unzip the data to {your_path}/PaddleOCR/train_data/data_lmdb_Release/training/path.
|
||||
|
||||
PaddleOCR provides label files for training the icdar2015 dataset, which can be downloaded in the following ways:
|
||||
|
||||
```
|
||||
|
@ -63,35 +99,8 @@ wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_t
|
|||
wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_test.txt
|
||||
```
|
||||
|
||||
The final training set should have the following file structure:
|
||||
|
||||
```
|
||||
|-train_data
|
||||
|-ic15_data
|
||||
|- rec_gt_train.txt
|
||||
|- train
|
||||
|- word_001.png
|
||||
|- word_002.jpg
|
||||
|- word_003.jpg
|
||||
| ...
|
||||
```
|
||||
|
||||
- Test set
|
||||
|
||||
Similar to the training set, the test set also needs to be provided a folder containing all images (test) and a rec_gt_test.txt. The structure of the test set is as follows:
|
||||
|
||||
```
|
||||
|-train_data
|
||||
|-ic15_data
|
||||
|- rec_gt_test.txt
|
||||
|- test
|
||||
|- word_001.jpg
|
||||
|- word_002.jpg
|
||||
|- word_003.jpg
|
||||
| ...
|
||||
```
|
||||
<a name="Dictionary"></a>
|
||||
- Dictionary
|
||||
#### 1.3 Dictionary
|
||||
|
||||
Finally, a dictionary ({word_dict_name}.txt) needs to be provided so that when the model is trained, all the characters that appear can be mapped to the dictionary index.
|
||||
|
||||
|
@ -108,6 +117,8 @@ n
|
|||
|
||||
In `word_dict.txt`, there is a single word in each line, which maps characters and numeric indexes together, e.g "and" will be mapped to [2 5 1]
|
||||
|
||||
PaddleOCR has built-in dictionaries, which can be used on demand.
|
||||
|
||||
`ppocr/utils/ppocr_keys_v1.txt` is a Chinese dictionary with 6623 characters.
|
||||
|
||||
`ppocr/utils/ic15_dict.txt` is an English dictionary with 63 characters
|
||||
|
@ -123,8 +134,6 @@ In `word_dict.txt`, there is a single word in each line, which maps characters a
|
|||
`ppocr/utils/dict/en_dict.txt` is a English dictionary with 63 characters
|
||||
|
||||
|
||||
You can use it on demand.
|
||||
|
||||
The current multi-language model is still in the demo stage and will continue to optimize the model and add languages. **You are very welcome to provide us with dictionaries and fonts in other languages**,
|
||||
If you like, you can submit the dictionary file to [dict](../../ppocr/utils/dict) and we will thank you in the Repo.
|
||||
|
||||
|
@ -136,14 +145,14 @@ To customize the dict file, please modify the `character_dict_path` field in `co
|
|||
If you need to customize dic file, please add character_dict_path field in configs/rec/rec_icdar15_train.yml to point to your dictionary path. And set character_type to ch.
|
||||
|
||||
<a name="Add_space_category"></a>
|
||||
- Add space category
|
||||
#### 1.4 Add space category
|
||||
|
||||
If you want to support the recognition of the `space` category, please set the `use_space_char` field in the yml file to `True`.
|
||||
|
||||
**Note: use_space_char only takes effect when character_type=ch**
|
||||
|
||||
<a name="TRAINING"></a>
|
||||
### TRAINING
|
||||
### 2 TRAINING
|
||||
|
||||
PaddleOCR provides training scripts, evaluation scripts, and prediction scripts. In this section, the CRNN recognition model will be used as an example:
|
||||
|
||||
|
@ -166,7 +175,7 @@ Start training:
|
|||
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_icdar15_train.yml
|
||||
```
|
||||
<a name="Data_Augmentation"></a>
|
||||
- Data Augmentation
|
||||
#### 2.1 Data Augmentation
|
||||
|
||||
PaddleOCR provides a variety of data augmentation methods. If you want to add disturbance during training, please set `distort: true` in the configuration file.
|
||||
|
||||
|
@ -175,7 +184,7 @@ The default perturbation methods are: cvtColor, blur, jitter, Gasuss noise, rand
|
|||
Each disturbance method is selected with a 50% probability during the training process. For specific code implementation, please refer to: [img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py)
|
||||
|
||||
<a name="Training"></a>
|
||||
- Training
|
||||
#### 2.2 Training
|
||||
|
||||
PaddleOCR supports alternating training and evaluation. You can modify `eval_batch_step` in `configs/rec/rec_icdar15_train.yml` to set the evaluation frequency. By default, it is evaluated every 500 iter and the best acc model is saved under `output/rec_CRNN/best_accuracy` during the evaluation process.
|
||||
|
||||
|
@ -268,7 +277,7 @@ Eval:
|
|||
**Note that the configuration file for prediction/evaluation must be consistent with the training.**
|
||||
|
||||
<a name="Multi_language"></a>
|
||||
- Multi-language
|
||||
#### 2.3 Multi-language
|
||||
|
||||
PaddleOCR currently supports 26 (except Chinese) language recognition. A multi-language configuration file template is
|
||||
provided under the path `configs/rec/multi_languages`: [rec_multi_language_lite_train.yml](../../configs/rec/multi_language/rec_multi_language_lite_train.yml)。
|
||||
|
@ -420,7 +429,7 @@ Eval:
|
|||
```
|
||||
|
||||
<a name="EVALUATION"></a>
|
||||
### EVALUATION
|
||||
### 3 EVALUATION
|
||||
|
||||
The evaluation dataset can be set by modifying the `Eval.dataset.label_file_list` field in the `configs/rec/rec_icdar15_train.yml` file.
|
||||
|
||||
|
@ -430,10 +439,10 @@ python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec
|
|||
```
|
||||
|
||||
<a name="PREDICTION"></a>
|
||||
### PREDICTION
|
||||
### 4 PREDICTION
|
||||
|
||||
<a name="Training_engine_prediction"></a>
|
||||
* Training engine prediction
|
||||
#### 4.1 Training engine prediction
|
||||
|
||||
Using the model trained by paddleocr, you can quickly get prediction through the following script.
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# paddleocr package
|
||||
|
||||
## Get started quickly
|
||||
### install package
|
||||
## 1 Get started quickly
|
||||
### 1.1 install package
|
||||
install by pypi
|
||||
```bash
|
||||
pip install "paddleocr>=2.0.1" # Recommend to use version 2.0.1+
|
||||
|
@ -12,9 +12,11 @@ build own whl package and install
|
|||
python3 setup.py bdist_wheel
|
||||
pip3 install dist/paddleocr-x.x.x-py3-none-any.whl # x.x.x is the version of paddleocr
|
||||
```
|
||||
### 1. Use by code
|
||||
## 2 Use
|
||||
### 2.1 Use by code
|
||||
The paddleocr whl package will automatically download the ppocr lightweight model as the default model, which can be customized and replaced according to the section 3 **Custom Model**.
|
||||
|
||||
* detection classification and recognition
|
||||
* detection angle classification and recognition
|
||||
```python
|
||||
from paddleocr import PaddleOCR,draw_ocr
|
||||
# Paddleocr supports Chinese, English, French, German, Korean and Japanese.
|
||||
|
@ -33,7 +35,7 @@ image = Image.open(img_path).convert('RGB')
|
|||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
@ -67,7 +69,7 @@ image = Image.open(img_path).convert('RGB')
|
|||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
@ -114,7 +116,7 @@ for line in result:
|
|||
from PIL import Image
|
||||
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||
im_show = draw_ocr(image, result, txts=None, scores=None, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
@ -163,7 +165,7 @@ Output will be a list, each item contains classification result and confidence
|
|||
['0', 0.99999964]
|
||||
```
|
||||
|
||||
### Use by command line
|
||||
### 2.2 Use by command line
|
||||
|
||||
show help information
|
||||
```bash
|
||||
|
@ -239,11 +241,11 @@ Output will be a list, each item contains classification result and confidence
|
|||
['0', 0.99999964]
|
||||
```
|
||||
|
||||
## Use custom model
|
||||
## 3 Use custom model
|
||||
When the built-in model cannot meet the needs, you need to use your own trained model.
|
||||
First, refer to the first section of [inference_en.md](./inference_en.md) to convert your det and rec model to inference model, and then use it as follows
|
||||
|
||||
### 1. Use by code
|
||||
### 3.1 Use by code
|
||||
|
||||
```python
|
||||
from paddleocr import PaddleOCR,draw_ocr
|
||||
|
@ -260,22 +262,22 @@ image = Image.open(img_path).convert('RGB')
|
|||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
### Use by command line
|
||||
### 3.2 Use by command line
|
||||
|
||||
```bash
|
||||
paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true
|
||||
```
|
||||
|
||||
### Use web images or numpy array as input
|
||||
## 4 Use web images or numpy array as input
|
||||
|
||||
1. Web image
|
||||
### 4.1 Web image
|
||||
|
||||
Use by code
|
||||
- Use by code
|
||||
```python
|
||||
from paddleocr import PaddleOCR, draw_ocr
|
||||
ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory
|
||||
|
@ -290,16 +292,16 @@ image = Image.open(img_path).convert('RGB')
|
|||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
Use by command line
|
||||
- Use by command line
|
||||
```bash
|
||||
paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg --use_angle_cls=true
|
||||
```
|
||||
|
||||
2. Numpy array
|
||||
### 4.2 Numpy array
|
||||
Support numpy array as input only when used by code
|
||||
|
||||
```python
|
||||
|
@ -318,13 +320,13 @@ image = Image.open(img_path).convert('RGB')
|
|||
boxes = [line[0] for line in result]
|
||||
txts = [line[1][0] for line in result]
|
||||
scores = [line[1][1] for line in result]
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf')
|
||||
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/simfang.ttf')
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
```
|
||||
|
||||
|
||||
## Parameter Description
|
||||
## 5 Parameter Description
|
||||
|
||||
| Parameter | Description | Default value |
|
||||
|-------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------------|
|
||||
|
|
After Width: | Height: | Size: 61 KiB |
After Width: | Height: | Size: 663 KiB |
After Width: | Height: | Size: 467 KiB |
After Width: | Height: | Size: 134 KiB |
After Width: | Height: | Size: 337 KiB |
BIN
doc/joinus.PNG
Before Width: | Height: | Size: 111 KiB After Width: | Height: | Size: 109 KiB |
After Width: | Height: | Size: 242 KiB |
21
paddleocr.py
|
@ -146,7 +146,8 @@ def parse_args(mMain=True, add_help=True):
|
|||
# DB parmas
|
||||
parser.add_argument("--det_db_thresh", type=float, default=0.3)
|
||||
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
|
||||
parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0)
|
||||
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
|
||||
parser.add_argument("--use_dilation", type=bool, default=False)
|
||||
|
||||
# EAST parmas
|
||||
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
|
||||
|
@ -193,7 +194,8 @@ def parse_args(mMain=True, add_help=True):
|
|||
det_limit_type='max',
|
||||
det_db_thresh=0.3,
|
||||
det_db_box_thresh=0.5,
|
||||
det_db_unclip_ratio=2.0,
|
||||
det_db_unclip_ratio=1.6,
|
||||
use_dilation=False,
|
||||
det_east_score_thresh=0.8,
|
||||
det_east_cover_thresh=0.1,
|
||||
det_east_nms_thresh=0.2,
|
||||
|
@ -234,7 +236,9 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
assert lang in model_urls[
|
||||
'rec'], 'param lang must in {}, but got {}'.format(
|
||||
model_urls['rec'].keys(), lang)
|
||||
use_inner_dict = False
|
||||
if postprocess_params.rec_char_dict_path is None:
|
||||
use_inner_dict = True
|
||||
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
|
||||
'dict_path']
|
||||
|
||||
|
@ -261,9 +265,9 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
|
||||
logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
|
||||
sys.exit(0)
|
||||
|
||||
postprocess_params.rec_char_dict_path = str(
|
||||
Path(__file__).parent / postprocess_params.rec_char_dict_path)
|
||||
if use_inner_dict:
|
||||
postprocess_params.rec_char_dict_path = str(
|
||||
Path(__file__).parent / postprocess_params.rec_char_dict_path)
|
||||
|
||||
# init det_model and rec_model
|
||||
super().__init__(postprocess_params)
|
||||
|
@ -280,8 +284,13 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
if isinstance(img, list) and det == True:
|
||||
logger.error('When input a list of images, det must be false')
|
||||
exit(0)
|
||||
if cls == False:
|
||||
self.use_angle_cls = False
|
||||
elif cls == True and self.use_angle_cls == False:
|
||||
logger.warning(
|
||||
'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process'
|
||||
)
|
||||
|
||||
self.use_angle_cls = cls
|
||||
if isinstance(img, str):
|
||||
# download net image
|
||||
if img.startswith('http'):
|
||||
|
|
|
@ -34,6 +34,7 @@ import paddle.distributed as dist
|
|||
from ppocr.data.imaug import transform, create_operators
|
||||
from ppocr.data.simple_dataset import SimpleDataSet
|
||||
from ppocr.data.lmdb_dataset import LMDBDataSet
|
||||
from ppocr.data.pgnet_dataset import PGDataSet
|
||||
|
||||
__all__ = ['build_dataloader', 'transform', 'create_operators']
|
||||
|
||||
|
@ -54,7 +55,7 @@ signal.signal(signal.SIGTERM, term_mp)
|
|||
def build_dataloader(config, mode, device, logger, seed=None):
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
support_dict = ['SimpleDataSet', 'LMDBDataSet']
|
||||
support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet']
|
||||
module_name = config[mode]['dataset']['name']
|
||||
assert module_name in support_dict, Exception(
|
||||
'DataSet only support {}'.format(support_dict))
|
||||
|
@ -72,14 +73,14 @@ def build_dataloader(config, mode, device, logger, seed=None):
|
|||
else:
|
||||
use_shared_memory = True
|
||||
if mode == "Train":
|
||||
#Distribute data to multiple cards
|
||||
# Distribute data to multiple cards
|
||||
batch_sampler = DistributedBatchSampler(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=shuffle,
|
||||
drop_last=drop_last)
|
||||
else:
|
||||
#Distribute data to single card
|
||||
# Distribute data to single card
|
||||
batch_sampler = BatchSampler(
|
||||
dataset=dataset,
|
||||
batch_size=batch_size,
|
||||
|
|
|
@ -28,6 +28,7 @@ from .label_ops import *
|
|||
|
||||
from .east_process import *
|
||||
from .sast_process import *
|
||||
from .pg_process import *
|
||||
|
||||
|
||||
def transform(data, ops=None):
|
||||
|
|
|
@ -187,6 +187,34 @@ class CTCLabelEncode(BaseRecLabelEncode):
|
|||
return dict_character
|
||||
|
||||
|
||||
class E2ELabelEncode(BaseRecLabelEncode):
|
||||
def __init__(self,
|
||||
max_text_length,
|
||||
character_dict_path=None,
|
||||
character_type='EN',
|
||||
use_space_char=False,
|
||||
**kwargs):
|
||||
super(E2ELabelEncode,
|
||||
self).__init__(max_text_length, character_dict_path,
|
||||
character_type, use_space_char)
|
||||
self.pad_num = len(self.dict) # the length to pad
|
||||
|
||||
def __call__(self, data):
|
||||
text_label_index_list, temp_text = [], []
|
||||
texts = data['strs']
|
||||
for text in texts:
|
||||
text = text.lower()
|
||||
temp_text = []
|
||||
for c_ in text:
|
||||
if c_ in self.dict:
|
||||
temp_text.append(self.dict[c_])
|
||||
temp_text = temp_text + [self.pad_num] * (self.max_text_len -
|
||||
len(temp_text))
|
||||
text_label_index_list.append(temp_text)
|
||||
data['strs'] = np.array(text_label_index_list)
|
||||
return data
|
||||
|
||||
|
||||
class AttnLabelEncode(BaseRecLabelEncode):
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
|
@ -215,7 +243,7 @@ class AttnLabelEncode(BaseRecLabelEncode):
|
|||
return None
|
||||
data['length'] = np.array(len(text))
|
||||
text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
|
||||
- len(text) - 1)
|
||||
- len(text) - 2)
|
||||
data['label'] = np.array(text)
|
||||
return data
|
||||
|
||||
|
@ -261,7 +289,7 @@ class SRNLabelEncode(BaseRecLabelEncode):
|
|||
if len(text) > self.max_text_len:
|
||||
return None
|
||||
data['length'] = np.array(len(text))
|
||||
text = text + [char_num] * (self.max_text_len - len(text))
|
||||
text = text + [char_num - 1] * (self.max_text_len - len(text))
|
||||
data['label'] = np.array(text)
|
||||
return data
|
||||
|
||||
|
|
|
@ -32,7 +32,6 @@ class MakeShrinkMap(object):
|
|||
text_polys, ignore_tags = self.validate_polygons(text_polys,
|
||||
ignore_tags, h, w)
|
||||
gt = np.zeros((h, w), dtype=np.float32)
|
||||
# gt = np.zeros((1, h, w), dtype=np.float32)
|
||||
mask = np.ones((h, w), dtype=np.float32)
|
||||
for i in range(len(text_polys)):
|
||||
polygon = text_polys[i]
|
||||
|
@ -44,21 +43,34 @@ class MakeShrinkMap(object):
|
|||
ignore_tags[i] = True
|
||||
else:
|
||||
polygon_shape = Polygon(polygon)
|
||||
distance = polygon_shape.area * (
|
||||
1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
|
||||
subject = [tuple(l) for l in text_polys[i]]
|
||||
subject = [tuple(l) for l in polygon]
|
||||
padding = pyclipper.PyclipperOffset()
|
||||
padding.AddPath(subject, pyclipper.JT_ROUND,
|
||||
pyclipper.ET_CLOSEDPOLYGON)
|
||||
shrinked = padding.Execute(-distance)
|
||||
shrinked = []
|
||||
|
||||
# Increase the shrink ratio every time we get multiple polygon returned back
|
||||
possible_ratios = np.arange(self.shrink_ratio, 1,
|
||||
self.shrink_ratio)
|
||||
np.append(possible_ratios, 1)
|
||||
# print(possible_ratios)
|
||||
for ratio in possible_ratios:
|
||||
# print(f"Change shrink ratio to {ratio}")
|
||||
distance = polygon_shape.area * (
|
||||
1 - np.power(ratio, 2)) / polygon_shape.length
|
||||
shrinked = padding.Execute(-distance)
|
||||
if len(shrinked) == 1:
|
||||
break
|
||||
|
||||
if shrinked == []:
|
||||
cv2.fillPoly(mask,
|
||||
polygon.astype(np.int32)[np.newaxis, :, :], 0)
|
||||
ignore_tags[i] = True
|
||||
continue
|
||||
shrinked = np.array(shrinked[0]).reshape(-1, 2)
|
||||
cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1)
|
||||
# cv2.fillPoly(gt[0], [shrinked.astype(np.int32)], 1)
|
||||
|
||||
for each_shirnk in shrinked:
|
||||
shirnk = np.array(each_shirnk).reshape(-1, 2)
|
||||
cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1)
|
||||
|
||||
data['shrink_map'] = gt
|
||||
data['shrink_mask'] = mask
|
||||
|
@ -84,11 +96,12 @@ class MakeShrinkMap(object):
|
|||
return polygons, ignore_tags
|
||||
|
||||
def polygon_area(self, polygon):
|
||||
# return cv2.contourArea(polygon.astype(np.float32))
|
||||
edge = 0
|
||||
for i in range(polygon.shape[0]):
|
||||
next_index = (i + 1) % polygon.shape[0]
|
||||
edge += (polygon[next_index, 0] - polygon[i, 0]) * (
|
||||
polygon[next_index, 1] - polygon[i, 1])
|
||||
|
||||
return edge / 2.
|
||||
"""
|
||||
compute polygon area
|
||||
"""
|
||||
area = 0
|
||||
q = polygon[-1]
|
||||
for p in polygon:
|
||||
area += p[0] * q[1] - p[1] * q[0]
|
||||
q = p
|
||||
return area / 2.0
|
||||
|
|
|
@ -185,8 +185,8 @@ class DetResizeForTest(object):
|
|||
resize_h = int(h * ratio)
|
||||
resize_w = int(w * ratio)
|
||||
|
||||
resize_h = int(round(resize_h / 32) * 32)
|
||||
resize_w = int(round(resize_w / 32) * 32)
|
||||
resize_h = max(int(round(resize_h / 32) * 32), 32)
|
||||
resize_w = max(int(round(resize_w / 32) * 32), 32)
|
||||
|
||||
try:
|
||||
if int(resize_w) <= 0 or int(resize_h) <= 0:
|
||||
|
@ -197,7 +197,6 @@ class DetResizeForTest(object):
|
|||
sys.exit(0)
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
# return img, np.array([h, w])
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
def resize_image_type2(self, img):
|
||||
|
@ -206,7 +205,6 @@ class DetResizeForTest(object):
|
|||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
# Fix the longer side
|
||||
if resize_h > resize_w:
|
||||
ratio = float(self.resize_long) / resize_h
|
||||
else:
|
||||
|
@ -223,3 +221,72 @@ class DetResizeForTest(object):
|
|||
ratio_w = resize_w / float(w)
|
||||
|
||||
return img, [ratio_h, ratio_w]
|
||||
|
||||
|
||||
class E2EResizeForTest(object):
|
||||
def __init__(self, **kwargs):
|
||||
super(E2EResizeForTest, self).__init__()
|
||||
self.max_side_len = kwargs['max_side_len']
|
||||
self.valid_set = kwargs['valid_set']
|
||||
|
||||
def __call__(self, data):
|
||||
img = data['image']
|
||||
src_h, src_w, _ = img.shape
|
||||
if self.valid_set == 'totaltext':
|
||||
im_resized, [ratio_h, ratio_w] = self.resize_image_for_totaltext(
|
||||
img, max_side_len=self.max_side_len)
|
||||
else:
|
||||
im_resized, (ratio_h, ratio_w) = self.resize_image(
|
||||
img, max_side_len=self.max_side_len)
|
||||
data['image'] = im_resized
|
||||
data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])
|
||||
return data
|
||||
|
||||
def resize_image_for_totaltext(self, im, max_side_len=512):
|
||||
|
||||
h, w, _ = im.shape
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
ratio = 1.25
|
||||
if h * ratio > max_side_len:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
def resize_image(self, im, max_side_len=512):
|
||||
"""
|
||||
resize image to a size multiple of max_stride which is required by the network
|
||||
:param im: the resized image
|
||||
:param max_side_len: limit of max image size to avoid out of memory in gpu
|
||||
:return: the resized image and the resize ratio
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
# Fix the longer side
|
||||
if resize_h > resize_w:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
else:
|
||||
ratio = float(max_side_len) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
|
|
@ -0,0 +1,906 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['PGProcessTrain']
|
||||
|
||||
|
||||
class PGProcessTrain(object):
|
||||
def __init__(self,
|
||||
character_dict_path,
|
||||
max_text_length,
|
||||
max_text_nums,
|
||||
tcl_len,
|
||||
batch_size=14,
|
||||
min_crop_size=24,
|
||||
min_text_size=4,
|
||||
max_text_size=512,
|
||||
**kwargs):
|
||||
self.tcl_len = tcl_len
|
||||
self.max_text_length = max_text_length
|
||||
self.max_text_nums = max_text_nums
|
||||
self.batch_size = batch_size
|
||||
self.min_crop_size = min_crop_size
|
||||
self.min_text_size = min_text_size
|
||||
self.max_text_size = max_text_size
|
||||
self.Lexicon_Table = self.get_dict(character_dict_path)
|
||||
self.pad_num = len(self.Lexicon_Table)
|
||||
self.img_id = 0
|
||||
|
||||
def get_dict(self, character_dict_path):
|
||||
character_str = ""
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||
character_str += line
|
||||
dict_character = list(character_str)
|
||||
return dict_character
|
||||
|
||||
def quad_area(self, poly):
|
||||
"""
|
||||
compute area of a polygon
|
||||
:param poly:
|
||||
:return:
|
||||
"""
|
||||
edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
|
||||
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
|
||||
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
|
||||
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
|
||||
return np.sum(edge) / 2.
|
||||
|
||||
def gen_quad_from_poly(self, poly):
|
||||
"""
|
||||
Generate min area quad from poly.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
||||
rect = cv2.minAreaRect(poly.astype(
|
||||
np.int32)) # (center (x,y), (width, height), angle of rotation)
|
||||
box = np.array(cv2.boxPoints(rect))
|
||||
|
||||
first_point_idx = 0
|
||||
min_dist = 1e4
|
||||
for i in range(4):
|
||||
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
|
||||
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
|
||||
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
|
||||
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
first_point_idx = i
|
||||
for i in range(4):
|
||||
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
||||
|
||||
return min_area_quad
|
||||
|
||||
def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
|
||||
"""
|
||||
check so that the text poly is in the same direction,
|
||||
and also filter some invalid polygons
|
||||
:param polys:
|
||||
:param tags:
|
||||
:return:
|
||||
"""
|
||||
(h, w) = xxx_todo_changeme
|
||||
if polys.shape[0] == 0:
|
||||
return polys, np.array([]), np.array([])
|
||||
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
|
||||
polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
|
||||
|
||||
validated_polys = []
|
||||
validated_tags = []
|
||||
hv_tags = []
|
||||
for poly, tag in zip(polys, tags):
|
||||
quad = self.gen_quad_from_poly(poly)
|
||||
p_area = self.quad_area(quad)
|
||||
if abs(p_area) < 1:
|
||||
print('invalid poly')
|
||||
continue
|
||||
if p_area > 0:
|
||||
if tag == False:
|
||||
print('poly in wrong direction')
|
||||
tag = True # reversed cases should be ignore
|
||||
poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
|
||||
1), :]
|
||||
quad = quad[(0, 3, 2, 1), :]
|
||||
|
||||
len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
|
||||
quad[2])
|
||||
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
|
||||
quad[2])
|
||||
hv_tag = 1
|
||||
|
||||
if len_w * 2.0 < len_h:
|
||||
hv_tag = 0
|
||||
|
||||
validated_polys.append(poly)
|
||||
validated_tags.append(tag)
|
||||
hv_tags.append(hv_tag)
|
||||
return np.array(validated_polys), np.array(validated_tags), np.array(
|
||||
hv_tags)
|
||||
|
||||
def crop_area(self,
|
||||
im,
|
||||
polys,
|
||||
tags,
|
||||
hv_tags,
|
||||
txts,
|
||||
crop_background=False,
|
||||
max_tries=25):
|
||||
"""
|
||||
make random crop from the input image
|
||||
:param im:
|
||||
:param polys: [b,4,2]
|
||||
:param tags:
|
||||
:param crop_background:
|
||||
:param max_tries: 50 -> 25
|
||||
:return:
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
pad_h = h // 10
|
||||
pad_w = w // 10
|
||||
h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
|
||||
w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
|
||||
for poly in polys:
|
||||
poly = np.round(poly, decimals=0).astype(np.int32)
|
||||
minx = np.min(poly[:, 0])
|
||||
maxx = np.max(poly[:, 0])
|
||||
w_array[minx + pad_w:maxx + pad_w] = 1
|
||||
miny = np.min(poly[:, 1])
|
||||
maxy = np.max(poly[:, 1])
|
||||
h_array[miny + pad_h:maxy + pad_h] = 1
|
||||
# ensure the cropped area not across a text
|
||||
h_axis = np.where(h_array == 0)[0]
|
||||
w_axis = np.where(w_array == 0)[0]
|
||||
if len(h_axis) == 0 or len(w_axis) == 0:
|
||||
return im, polys, tags, hv_tags, txts
|
||||
for i in range(max_tries):
|
||||
xx = np.random.choice(w_axis, size=2)
|
||||
xmin = np.min(xx) - pad_w
|
||||
xmax = np.max(xx) - pad_w
|
||||
xmin = np.clip(xmin, 0, w - 1)
|
||||
xmax = np.clip(xmax, 0, w - 1)
|
||||
yy = np.random.choice(h_axis, size=2)
|
||||
ymin = np.min(yy) - pad_h
|
||||
ymax = np.max(yy) - pad_h
|
||||
ymin = np.clip(ymin, 0, h - 1)
|
||||
ymax = np.clip(ymax, 0, h - 1)
|
||||
if xmax - xmin < self.min_crop_size or \
|
||||
ymax - ymin < self.min_crop_size:
|
||||
continue
|
||||
if polys.shape[0] != 0:
|
||||
poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
|
||||
& (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
|
||||
selected_polys = np.where(
|
||||
np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
||||
else:
|
||||
selected_polys = []
|
||||
if len(selected_polys) == 0:
|
||||
# no text in this area
|
||||
if crop_background:
|
||||
txts_tmp = []
|
||||
for selected_poly in selected_polys:
|
||||
txts_tmp.append(txts[selected_poly])
|
||||
txts = txts_tmp
|
||||
return im[ymin: ymax + 1, xmin: xmax + 1, :], \
|
||||
polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
|
||||
else:
|
||||
continue
|
||||
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
|
||||
polys = polys[selected_polys]
|
||||
tags = tags[selected_polys]
|
||||
hv_tags = hv_tags[selected_polys]
|
||||
txts_tmp = []
|
||||
for selected_poly in selected_polys:
|
||||
txts_tmp.append(txts[selected_poly])
|
||||
txts = txts_tmp
|
||||
polys[:, :, 0] -= xmin
|
||||
polys[:, :, 1] -= ymin
|
||||
return im, polys, tags, hv_tags, txts
|
||||
|
||||
return im, polys, tags, hv_tags, txts
|
||||
|
||||
def fit_and_gather_tcl_points_v2(self,
|
||||
min_area_quad,
|
||||
poly,
|
||||
max_h,
|
||||
max_w,
|
||||
fixed_point_num=64,
|
||||
img_id=0,
|
||||
reference_height=3):
|
||||
"""
|
||||
Find the center point of poly as key_points, then fit and gather.
|
||||
"""
|
||||
key_point_xys = []
|
||||
point_num = poly.shape[0]
|
||||
for idx in range(point_num // 2):
|
||||
center_point = (poly[idx] + poly[point_num - 1 - idx]) / 2.0
|
||||
key_point_xys.append(center_point)
|
||||
|
||||
tmp_image = np.zeros(
|
||||
shape=(
|
||||
max_h,
|
||||
max_w, ), dtype='float32')
|
||||
cv2.polylines(tmp_image, [np.array(key_point_xys).astype('int32')],
|
||||
False, 1.0)
|
||||
ys, xs = np.where(tmp_image > 0)
|
||||
xy_text = np.array(list(zip(xs, ys)), dtype='float32')
|
||||
|
||||
left_center_pt = (
|
||||
(min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2)
|
||||
right_center_pt = (
|
||||
(min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2)
|
||||
proj_unit_vec = (right_center_pt - left_center_pt) / (
|
||||
np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
|
||||
proj_unit_vec_tile = np.tile(proj_unit_vec,
|
||||
(xy_text.shape[0], 1)) # (n, 2)
|
||||
left_center_pt_tile = np.tile(left_center_pt,
|
||||
(xy_text.shape[0], 1)) # (n, 2)
|
||||
xy_text_to_left_center = xy_text - left_center_pt_tile
|
||||
proj_value = np.sum(xy_text_to_left_center * proj_unit_vec_tile, axis=1)
|
||||
xy_text = xy_text[np.argsort(proj_value)]
|
||||
|
||||
# convert to np and keep the num of point not greater then fixed_point_num
|
||||
pos_info = np.array(xy_text).reshape(-1, 2)[:, ::-1] # xy-> yx
|
||||
point_num = len(pos_info)
|
||||
if point_num > fixed_point_num:
|
||||
keep_ids = [
|
||||
int((point_num * 1.0 / fixed_point_num) * x)
|
||||
for x in range(fixed_point_num)
|
||||
]
|
||||
pos_info = pos_info[keep_ids, :]
|
||||
|
||||
keep = int(min(len(pos_info), fixed_point_num))
|
||||
if np.random.rand() < 0.2 and reference_height >= 3:
|
||||
dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3
|
||||
random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape(
|
||||
[keep, 1])
|
||||
pos_info += random_float
|
||||
pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
|
||||
pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)
|
||||
|
||||
# padding to fixed length
|
||||
pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
|
||||
pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
|
||||
pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
|
||||
pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
|
||||
pos_m[:keep] = 1.0
|
||||
return pos_l, pos_m
|
||||
|
||||
def generate_direction_map(self, poly_quads, n_char, direction_map):
|
||||
"""
|
||||
"""
|
||||
width_list = []
|
||||
height_list = []
|
||||
for quad in poly_quads:
|
||||
quad_w = (np.linalg.norm(quad[0] - quad[1]) +
|
||||
np.linalg.norm(quad[2] - quad[3])) / 2.0
|
||||
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
|
||||
np.linalg.norm(quad[2] - quad[1])) / 2.0
|
||||
width_list.append(quad_w)
|
||||
height_list.append(quad_h)
|
||||
norm_width = max(sum(width_list) / n_char, 1.0)
|
||||
average_height = max(sum(height_list) / len(height_list), 1.0)
|
||||
k = 1
|
||||
for quad in poly_quads:
|
||||
direct_vector_full = (
|
||||
(quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
|
||||
direct_vector = direct_vector_full / (
|
||||
np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
|
||||
direction_label = tuple(
|
||||
map(float,
|
||||
[direct_vector[0], direct_vector[1], 1.0 / average_height]))
|
||||
cv2.fillPoly(direction_map,
|
||||
quad.round().astype(np.int32)[np.newaxis, :, :],
|
||||
direction_label)
|
||||
k += 1
|
||||
return direction_map
|
||||
|
||||
def calculate_average_height(self, poly_quads):
|
||||
"""
|
||||
"""
|
||||
height_list = []
|
||||
for quad in poly_quads:
|
||||
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
|
||||
np.linalg.norm(quad[2] - quad[1])) / 2.0
|
||||
height_list.append(quad_h)
|
||||
average_height = max(sum(height_list) / len(height_list), 1.0)
|
||||
return average_height
|
||||
|
||||
def generate_tcl_ctc_label(self,
|
||||
h,
|
||||
w,
|
||||
polys,
|
||||
tags,
|
||||
text_strs,
|
||||
ds_ratio,
|
||||
tcl_ratio=0.3,
|
||||
shrink_ratio_of_width=0.15):
|
||||
"""
|
||||
Generate polygon.
|
||||
"""
|
||||
score_map_big = np.zeros(
|
||||
(
|
||||
h,
|
||||
w, ), dtype=np.float32)
|
||||
h, w = int(h * ds_ratio), int(w * ds_ratio)
|
||||
polys = polys * ds_ratio
|
||||
|
||||
score_map = np.zeros(
|
||||
(
|
||||
h,
|
||||
w, ), dtype=np.float32)
|
||||
score_label_map = np.zeros(
|
||||
(
|
||||
h,
|
||||
w, ), dtype=np.float32)
|
||||
tbo_map = np.zeros((h, w, 5), dtype=np.float32)
|
||||
training_mask = np.ones(
|
||||
(
|
||||
h,
|
||||
w, ), dtype=np.float32)
|
||||
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
|
||||
[1, 1, 3]).astype(np.float32)
|
||||
|
||||
label_idx = 0
|
||||
score_label_map_text_label_list = []
|
||||
pos_list, pos_mask, label_list = [], [], []
|
||||
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
|
||||
poly = poly_tag[0]
|
||||
tag = poly_tag[1]
|
||||
|
||||
# generate min_area_quad
|
||||
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
|
||||
min_area_quad_h = 0.5 * (
|
||||
np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
|
||||
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
|
||||
min_area_quad_w = 0.5 * (
|
||||
np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
|
||||
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
|
||||
|
||||
if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
|
||||
or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
|
||||
continue
|
||||
|
||||
if tag:
|
||||
cv2.fillPoly(training_mask,
|
||||
poly.astype(np.int32)[np.newaxis, :, :], 0.15)
|
||||
else:
|
||||
text_label = text_strs[poly_idx]
|
||||
text_label = self.prepare_text_label(text_label,
|
||||
self.Lexicon_Table)
|
||||
|
||||
text_label_index_list = [[self.Lexicon_Table.index(c_)]
|
||||
for c_ in text_label
|
||||
if c_ in self.Lexicon_Table]
|
||||
if len(text_label_index_list) < 1:
|
||||
continue
|
||||
|
||||
tcl_poly = self.poly2tcl(poly, tcl_ratio)
|
||||
tcl_quads = self.poly2quads(tcl_poly)
|
||||
poly_quads = self.poly2quads(poly)
|
||||
|
||||
stcl_quads, quad_index = self.shrink_poly_along_width(
|
||||
tcl_quads,
|
||||
shrink_ratio_of_width=shrink_ratio_of_width,
|
||||
expand_height_ratio=1.0 / tcl_ratio)
|
||||
|
||||
cv2.fillPoly(score_map,
|
||||
np.round(stcl_quads).astype(np.int32), 1.0)
|
||||
cv2.fillPoly(score_map_big,
|
||||
np.round(stcl_quads / ds_ratio).astype(np.int32),
|
||||
1.0)
|
||||
|
||||
for idx, quad in enumerate(stcl_quads):
|
||||
quad_mask = np.zeros((h, w), dtype=np.float32)
|
||||
quad_mask = cv2.fillPoly(
|
||||
quad_mask,
|
||||
np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
|
||||
tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
|
||||
quad_mask, tbo_map)
|
||||
|
||||
# score label map and score_label_map_text_label_list for refine
|
||||
if label_idx == 0:
|
||||
text_pos_list_ = [[len(self.Lexicon_Table)], ]
|
||||
score_label_map_text_label_list.append(text_pos_list_)
|
||||
|
||||
label_idx += 1
|
||||
cv2.fillPoly(score_label_map,
|
||||
np.round(poly_quads).astype(np.int32), label_idx)
|
||||
score_label_map_text_label_list.append(text_label_index_list)
|
||||
|
||||
# direction info, fix-me
|
||||
n_char = len(text_label_index_list)
|
||||
direction_map = self.generate_direction_map(poly_quads, n_char,
|
||||
direction_map)
|
||||
|
||||
# pos info
|
||||
average_shrink_height = self.calculate_average_height(
|
||||
stcl_quads)
|
||||
pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
|
||||
min_area_quad,
|
||||
poly,
|
||||
max_h=h,
|
||||
max_w=w,
|
||||
fixed_point_num=64,
|
||||
img_id=self.img_id,
|
||||
reference_height=average_shrink_height)
|
||||
|
||||
label_l = text_label_index_list
|
||||
if len(text_label_index_list) < 2:
|
||||
continue
|
||||
|
||||
pos_list.append(pos_l)
|
||||
pos_mask.append(pos_m)
|
||||
label_list.append(label_l)
|
||||
|
||||
# use big score_map for smooth tcl lines
|
||||
score_map_big_resized = cv2.resize(
|
||||
score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio)
|
||||
score_map = np.array(score_map_big_resized > 1e-3, dtype='float32')
|
||||
|
||||
return score_map, score_label_map, tbo_map, direction_map, training_mask, \
|
||||
pos_list, pos_mask, label_list, score_label_map_text_label_list
|
||||
|
||||
def adjust_point(self, poly):
|
||||
"""
|
||||
adjust point order.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
if point_num == 4:
|
||||
len_1 = np.linalg.norm(poly[0] - poly[1])
|
||||
len_2 = np.linalg.norm(poly[1] - poly[2])
|
||||
len_3 = np.linalg.norm(poly[2] - poly[3])
|
||||
len_4 = np.linalg.norm(poly[3] - poly[0])
|
||||
|
||||
if (len_1 + len_3) * 1.5 < (len_2 + len_4):
|
||||
poly = poly[[1, 2, 3, 0], :]
|
||||
|
||||
elif point_num > 4:
|
||||
vector_1 = poly[0] - poly[1]
|
||||
vector_2 = poly[1] - poly[2]
|
||||
cos_theta = np.dot(vector_1, vector_2) / (
|
||||
np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
|
||||
theta = np.arccos(np.round(cos_theta, decimals=4))
|
||||
|
||||
if abs(theta) > (70 / 180 * math.pi):
|
||||
index = list(range(1, point_num)) + [0]
|
||||
poly = poly[np.array(index), :]
|
||||
return poly
|
||||
|
||||
def gen_min_area_quad_from_poly(self, poly):
|
||||
"""
|
||||
Generate min area quad from poly.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
||||
if point_num == 4:
|
||||
min_area_quad = poly
|
||||
center_point = np.sum(poly, axis=0) / 4
|
||||
else:
|
||||
rect = cv2.minAreaRect(poly.astype(
|
||||
np.int32)) # (center (x,y), (width, height), angle of rotation)
|
||||
center_point = rect[0]
|
||||
box = np.array(cv2.boxPoints(rect))
|
||||
|
||||
first_point_idx = 0
|
||||
min_dist = 1e4
|
||||
for i in range(4):
|
||||
dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
|
||||
np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
|
||||
np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
|
||||
np.linalg.norm(box[(i + 3) % 4] - poly[-1])
|
||||
if dist < min_dist:
|
||||
min_dist = dist
|
||||
first_point_idx = i
|
||||
|
||||
for i in range(4):
|
||||
min_area_quad[i] = box[(first_point_idx + i) % 4]
|
||||
|
||||
return min_area_quad, center_point
|
||||
|
||||
def shrink_quad_along_width(self,
|
||||
quad,
|
||||
begin_width_ratio=0.,
|
||||
end_width_ratio=1.):
|
||||
"""
|
||||
Generate shrink_quad_along_width.
|
||||
"""
|
||||
ratio_pair = np.array(
|
||||
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
||||
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
||||
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
||||
|
||||
def shrink_poly_along_width(self,
|
||||
quads,
|
||||
shrink_ratio_of_width,
|
||||
expand_height_ratio=1.0):
|
||||
"""
|
||||
shrink poly with given length.
|
||||
"""
|
||||
upper_edge_list = []
|
||||
|
||||
def get_cut_info(edge_len_list, cut_len):
|
||||
for idx, edge_len in enumerate(edge_len_list):
|
||||
cut_len -= edge_len
|
||||
if cut_len <= 0.000001:
|
||||
ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
|
||||
return idx, ratio
|
||||
|
||||
for quad in quads:
|
||||
upper_edge_len = np.linalg.norm(quad[0] - quad[1])
|
||||
upper_edge_list.append(upper_edge_len)
|
||||
|
||||
# length of left edge and right edge.
|
||||
left_length = np.linalg.norm(quads[0][0] - quads[0][
|
||||
3]) * expand_height_ratio
|
||||
right_length = np.linalg.norm(quads[-1][1] - quads[-1][
|
||||
2]) * expand_height_ratio
|
||||
|
||||
shrink_length = min(left_length, right_length,
|
||||
sum(upper_edge_list)) * shrink_ratio_of_width
|
||||
# shrinking length
|
||||
upper_len_left = shrink_length
|
||||
upper_len_right = sum(upper_edge_list) - shrink_length
|
||||
|
||||
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
|
||||
left_quad = self.shrink_quad_along_width(
|
||||
quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
|
||||
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
|
||||
right_quad = self.shrink_quad_along_width(
|
||||
quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
|
||||
|
||||
out_quad_list = []
|
||||
if left_idx == right_idx:
|
||||
out_quad_list.append(
|
||||
[left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
|
||||
else:
|
||||
out_quad_list.append(left_quad)
|
||||
for idx in range(left_idx + 1, right_idx):
|
||||
out_quad_list.append(quads[idx])
|
||||
out_quad_list.append(right_quad)
|
||||
|
||||
return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
|
||||
|
||||
def prepare_text_label(self, label_str, Lexicon_Table):
|
||||
"""
|
||||
Prepare text lablel by given Lexicon_Table.
|
||||
"""
|
||||
if len(Lexicon_Table) == 36:
|
||||
return label_str.lower()
|
||||
else:
|
||||
return label_str
|
||||
|
||||
def vector_angle(self, A, B):
|
||||
"""
|
||||
Calculate the angle between vector AB and x-axis positive direction.
|
||||
"""
|
||||
AB = np.array([B[1] - A[1], B[0] - A[0]])
|
||||
return np.arctan2(*AB)
|
||||
|
||||
def theta_line_cross_point(self, theta, point):
|
||||
"""
|
||||
Calculate the line through given point and angle in ax + by + c =0 form.
|
||||
"""
|
||||
x, y = point
|
||||
cos = np.cos(theta)
|
||||
sin = np.sin(theta)
|
||||
return [sin, -cos, cos * y - sin * x]
|
||||
|
||||
def line_cross_two_point(self, A, B):
|
||||
"""
|
||||
Calculate the line through given point A and B in ax + by + c =0 form.
|
||||
"""
|
||||
angle = self.vector_angle(A, B)
|
||||
return self.theta_line_cross_point(angle, A)
|
||||
|
||||
def average_angle(self, poly):
|
||||
"""
|
||||
Calculate the average angle between left and right edge in given poly.
|
||||
"""
|
||||
p0, p1, p2, p3 = poly
|
||||
angle30 = self.vector_angle(p3, p0)
|
||||
angle21 = self.vector_angle(p2, p1)
|
||||
return (angle30 + angle21) / 2
|
||||
|
||||
def line_cross_point(self, line1, line2):
|
||||
"""
|
||||
line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2
|
||||
"""
|
||||
a1, b1, c1 = line1
|
||||
a2, b2, c2 = line2
|
||||
d = a1 * b2 - a2 * b1
|
||||
|
||||
if d == 0:
|
||||
print('Cross point does not exist')
|
||||
return np.array([0, 0], dtype=np.float32)
|
||||
else:
|
||||
x = (b1 * c2 - b2 * c1) / d
|
||||
y = (a2 * c1 - a1 * c2) / d
|
||||
|
||||
return np.array([x, y], dtype=np.float32)
|
||||
|
||||
def quad2tcl(self, poly, ratio):
|
||||
"""
|
||||
Generate center line by poly clock-wise point. (4, 2)
|
||||
"""
|
||||
ratio_pair = np.array(
|
||||
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
||||
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
|
||||
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
|
||||
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
|
||||
|
||||
def poly2tcl(self, poly, ratio):
|
||||
"""
|
||||
Generate center line by poly clock-wise point.
|
||||
"""
|
||||
ratio_pair = np.array(
|
||||
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
||||
tcl_poly = np.zeros_like(poly)
|
||||
point_num = poly.shape[0]
|
||||
|
||||
for idx in range(point_num // 2):
|
||||
point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
|
||||
) * ratio_pair
|
||||
tcl_poly[idx] = point_pair[0]
|
||||
tcl_poly[point_num - 1 - idx] = point_pair[1]
|
||||
return tcl_poly
|
||||
|
||||
def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
|
||||
"""
|
||||
Generate tbo_map for give quad.
|
||||
"""
|
||||
# upper and lower line function: ax + by + c = 0;
|
||||
up_line = self.line_cross_two_point(quad[0], quad[1])
|
||||
lower_line = self.line_cross_two_point(quad[3], quad[2])
|
||||
|
||||
quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
|
||||
np.linalg.norm(quad[1] - quad[2]))
|
||||
quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
|
||||
np.linalg.norm(quad[2] - quad[3]))
|
||||
|
||||
# average angle of left and right line.
|
||||
angle = self.average_angle(quad)
|
||||
|
||||
xy_in_poly = np.argwhere(tcl_mask == 1)
|
||||
for y, x in xy_in_poly:
|
||||
point = (x, y)
|
||||
line = self.theta_line_cross_point(angle, point)
|
||||
cross_point_upper = self.line_cross_point(up_line, line)
|
||||
cross_point_lower = self.line_cross_point(lower_line, line)
|
||||
##FIX, offset reverse
|
||||
upper_offset_x, upper_offset_y = cross_point_upper - point
|
||||
lower_offset_x, lower_offset_y = cross_point_lower - point
|
||||
tbo_map[y, x, 0] = upper_offset_y
|
||||
tbo_map[y, x, 1] = upper_offset_x
|
||||
tbo_map[y, x, 2] = lower_offset_y
|
||||
tbo_map[y, x, 3] = lower_offset_x
|
||||
tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
|
||||
return tbo_map
|
||||
|
||||
def poly2quads(self, poly):
|
||||
"""
|
||||
Split poly into quads.
|
||||
"""
|
||||
quad_list = []
|
||||
point_num = poly.shape[0]
|
||||
|
||||
# point pair
|
||||
point_pair_list = []
|
||||
for idx in range(point_num // 2):
|
||||
point_pair = [poly[idx], poly[point_num - 1 - idx]]
|
||||
point_pair_list.append(point_pair)
|
||||
|
||||
quad_num = point_num // 2 - 1
|
||||
for idx in range(quad_num):
|
||||
# reshape and adjust to clock-wise
|
||||
quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
|
||||
).reshape(4, 2)[[0, 2, 3, 1]])
|
||||
|
||||
return np.array(quad_list)
|
||||
|
||||
def rotate_im_poly(self, im, text_polys):
|
||||
"""
|
||||
rotate image with 90 / 180 / 270 degre
|
||||
"""
|
||||
im_w, im_h = im.shape[1], im.shape[0]
|
||||
dst_im = im.copy()
|
||||
dst_polys = []
|
||||
rand_degree_ratio = np.random.rand()
|
||||
rand_degree_cnt = 1
|
||||
if rand_degree_ratio > 0.5:
|
||||
rand_degree_cnt = 3
|
||||
for i in range(rand_degree_cnt):
|
||||
dst_im = np.rot90(dst_im)
|
||||
rot_degree = -90 * rand_degree_cnt
|
||||
rot_angle = rot_degree * math.pi / 180.0
|
||||
n_poly = text_polys.shape[0]
|
||||
cx, cy = 0.5 * im_w, 0.5 * im_h
|
||||
ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]
|
||||
for i in range(n_poly):
|
||||
wordBB = text_polys[i]
|
||||
poly = []
|
||||
for j in range(4): # 16->4
|
||||
sx, sy = wordBB[j][0], wordBB[j][1]
|
||||
dx = math.cos(rot_angle) * (sx - cx) - math.sin(rot_angle) * (
|
||||
sy - cy) + ncx
|
||||
dy = math.sin(rot_angle) * (sx - cx) + math.cos(rot_angle) * (
|
||||
sy - cy) + ncy
|
||||
poly.append([dx, dy])
|
||||
dst_polys.append(poly)
|
||||
return dst_im, np.array(dst_polys, dtype=np.float32)
|
||||
|
||||
def __call__(self, data):
|
||||
input_size = 512
|
||||
im = data['image']
|
||||
text_polys = data['polys']
|
||||
text_tags = data['tags']
|
||||
text_strs = data['strs']
|
||||
h, w, _ = im.shape
|
||||
text_polys, text_tags, hv_tags = self.check_and_validate_polys(
|
||||
text_polys, text_tags, (h, w))
|
||||
if text_polys.shape[0] <= 0:
|
||||
return None
|
||||
# set aspect ratio and keep area fix
|
||||
asp_scales = np.arange(1.0, 1.55, 0.1)
|
||||
asp_scale = np.random.choice(asp_scales)
|
||||
if np.random.rand() < 0.5:
|
||||
asp_scale = 1.0 / asp_scale
|
||||
asp_scale = math.sqrt(asp_scale)
|
||||
|
||||
asp_wx = asp_scale
|
||||
asp_hy = 1.0 / asp_scale
|
||||
im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
|
||||
text_polys[:, :, 0] *= asp_wx
|
||||
text_polys[:, :, 1] *= asp_hy
|
||||
|
||||
h, w, _ = im.shape
|
||||
if max(h, w) > 2048:
|
||||
rd_scale = 2048.0 / max(h, w)
|
||||
im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale)
|
||||
text_polys *= rd_scale
|
||||
h, w, _ = im.shape
|
||||
if min(h, w) < 16:
|
||||
return None
|
||||
|
||||
# no background
|
||||
im, text_polys, text_tags, hv_tags, text_strs = self.crop_area(
|
||||
im,
|
||||
text_polys,
|
||||
text_tags,
|
||||
hv_tags,
|
||||
text_strs,
|
||||
crop_background=False)
|
||||
|
||||
if text_polys.shape[0] == 0:
|
||||
return None
|
||||
# # continue for all ignore case
|
||||
if np.sum((text_tags * 1.0)) >= text_tags.size:
|
||||
return None
|
||||
new_h, new_w, _ = im.shape
|
||||
if (new_h is None) or (new_w is None):
|
||||
return None
|
||||
# resize image
|
||||
std_ratio = float(input_size) / max(new_w, new_h)
|
||||
rand_scales = np.array(
|
||||
[0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
|
||||
rz_scale = std_ratio * np.random.choice(rand_scales)
|
||||
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
|
||||
text_polys[:, :, 0] *= rz_scale
|
||||
text_polys[:, :, 1] *= rz_scale
|
||||
|
||||
# add gaussian blur
|
||||
if np.random.rand() < 0.1 * 0.5:
|
||||
ks = np.random.permutation(5)[0] + 1
|
||||
ks = int(ks / 2) * 2 + 1
|
||||
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
|
||||
# add brighter
|
||||
if np.random.rand() < 0.1 * 0.5:
|
||||
im = im * (1.0 + np.random.rand() * 0.5)
|
||||
im = np.clip(im, 0.0, 255.0)
|
||||
# add darker
|
||||
if np.random.rand() < 0.1 * 0.5:
|
||||
im = im * (1.0 - np.random.rand() * 0.5)
|
||||
im = np.clip(im, 0.0, 255.0)
|
||||
|
||||
# Padding the im to [input_size, input_size]
|
||||
new_h, new_w, _ = im.shape
|
||||
if min(new_w, new_h) < input_size * 0.5:
|
||||
return None
|
||||
im_padded = np.ones((input_size, input_size, 3), dtype=np.float32)
|
||||
im_padded[:, :, 2] = 0.485 * 255
|
||||
im_padded[:, :, 1] = 0.456 * 255
|
||||
im_padded[:, :, 0] = 0.406 * 255
|
||||
|
||||
# Random the start position
|
||||
del_h = input_size - new_h
|
||||
del_w = input_size - new_w
|
||||
sh, sw = 0, 0
|
||||
if del_h > 1:
|
||||
sh = int(np.random.rand() * del_h)
|
||||
if del_w > 1:
|
||||
sw = int(np.random.rand() * del_w)
|
||||
|
||||
# Padding
|
||||
im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
|
||||
text_polys[:, :, 0] += sw
|
||||
text_polys[:, :, 1] += sh
|
||||
|
||||
score_map, score_label_map, border_map, direction_map, training_mask, \
|
||||
pos_list, pos_mask, label_list, score_label_map_text_label = self.generate_tcl_ctc_label(input_size,
|
||||
input_size,
|
||||
text_polys,
|
||||
text_tags,
|
||||
text_strs, 0.25)
|
||||
if len(label_list) <= 0: # eliminate negative samples
|
||||
return None
|
||||
pos_list_temp = np.zeros([64, 3])
|
||||
pos_mask_temp = np.zeros([64, 1])
|
||||
label_list_temp = np.zeros([self.max_text_length, 1]) + self.pad_num
|
||||
|
||||
for i, label in enumerate(label_list):
|
||||
n = len(label)
|
||||
if n > self.max_text_length:
|
||||
label_list[i] = label[:self.max_text_length]
|
||||
continue
|
||||
while n < self.max_text_length:
|
||||
label.append([self.pad_num])
|
||||
n += 1
|
||||
|
||||
for i in range(len(label_list)):
|
||||
label_list[i] = np.array(label_list[i])
|
||||
|
||||
if len(pos_list) <= 0 or len(pos_list) > self.max_text_nums:
|
||||
return None
|
||||
for __ in range(self.max_text_nums - len(pos_list), 0, -1):
|
||||
pos_list.append(pos_list_temp)
|
||||
pos_mask.append(pos_mask_temp)
|
||||
label_list.append(label_list_temp)
|
||||
|
||||
if self.img_id == self.batch_size - 1:
|
||||
self.img_id = 0
|
||||
else:
|
||||
self.img_id += 1
|
||||
|
||||
im_padded[:, :, 2] -= 0.485 * 255
|
||||
im_padded[:, :, 1] -= 0.456 * 255
|
||||
im_padded[:, :, 0] -= 0.406 * 255
|
||||
im_padded[:, :, 2] /= (255.0 * 0.229)
|
||||
im_padded[:, :, 1] /= (255.0 * 0.224)
|
||||
im_padded[:, :, 0] /= (255.0 * 0.225)
|
||||
im_padded = im_padded.transpose((2, 0, 1))
|
||||
images = im_padded[::-1, :, :]
|
||||
tcl_maps = score_map[np.newaxis, :, :]
|
||||
tcl_label_maps = score_label_map[np.newaxis, :, :]
|
||||
border_maps = border_map.transpose((2, 0, 1))
|
||||
direction_maps = direction_map.transpose((2, 0, 1))
|
||||
training_masks = training_mask[np.newaxis, :, :]
|
||||
pos_list = np.array(pos_list)
|
||||
pos_mask = np.array(pos_mask)
|
||||
label_list = np.array(label_list)
|
||||
data['images'] = images
|
||||
data['tcl_maps'] = tcl_maps
|
||||
data['tcl_label_maps'] = tcl_label_maps
|
||||
data['border_maps'] = border_maps
|
||||
data['direction_maps'] = direction_maps
|
||||
data['training_masks'] = training_masks
|
||||
data['label_list'] = label_list
|
||||
data['pos_list'] = pos_list
|
||||
data['pos_mask'] = pos_mask
|
||||
return data
|
|
@ -117,13 +117,16 @@ class RawRandAugment(object):
|
|||
class RandAugment(RawRandAugment):
|
||||
""" RandAugment wrapper to auto fit different img types """
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, prob=0.5, *args, **kwargs):
|
||||
self.prob = prob
|
||||
if six.PY2:
|
||||
super(RandAugment, self).__init__(*args, **kwargs)
|
||||
else:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __call__(self, data):
|
||||
if np.random.rand() > self.prob:
|
||||
return data
|
||||
img = data['image']
|
||||
if not isinstance(img, Image.Image):
|
||||
img = np.ascontiguousarray(img)
|
||||
|
|
|
@ -0,0 +1,175 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import os
|
||||
from paddle.io import Dataset
|
||||
from .imaug import transform, create_operators
|
||||
import random
|
||||
|
||||
|
||||
class PGDataSet(Dataset):
|
||||
def __init__(self, config, mode, logger, seed=None):
|
||||
super(PGDataSet, self).__init__()
|
||||
|
||||
self.logger = logger
|
||||
self.seed = seed
|
||||
self.mode = mode
|
||||
global_config = config['Global']
|
||||
dataset_config = config[mode]['dataset']
|
||||
loader_config = config[mode]['loader']
|
||||
|
||||
label_file_list = dataset_config.pop('label_file_list')
|
||||
data_source_num = len(label_file_list)
|
||||
ratio_list = dataset_config.get("ratio_list", [1.0])
|
||||
if isinstance(ratio_list, (float, int)):
|
||||
ratio_list = [float(ratio_list)] * int(data_source_num)
|
||||
self.data_format = dataset_config.get('data_format', 'icdar')
|
||||
assert len(
|
||||
ratio_list
|
||||
) == data_source_num, "The length of ratio_list should be the same as the file_list."
|
||||
self.do_shuffle = loader_config['shuffle']
|
||||
|
||||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
|
||||
self.data_format)
|
||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||
if mode.lower() == "train":
|
||||
self.shuffle_data_random()
|
||||
|
||||
self.ops = create_operators(dataset_config['transforms'], global_config)
|
||||
|
||||
def shuffle_data_random(self):
|
||||
if self.do_shuffle:
|
||||
random.seed(self.seed)
|
||||
random.shuffle(self.data_lines)
|
||||
return
|
||||
|
||||
def extract_polys(self, poly_txt_path):
|
||||
"""
|
||||
Read text_polys, txt_tags, txts from give txt file.
|
||||
"""
|
||||
text_polys, txt_tags, txts = [], [], []
|
||||
with open(poly_txt_path) as f:
|
||||
for line in f.readlines():
|
||||
poly_str, txt = line.strip().split('\t')
|
||||
poly = list(map(float, poly_str.split(',')))
|
||||
if self.mode.lower() == "eval":
|
||||
while len(poly) < 100:
|
||||
poly.append(-1)
|
||||
text_polys.append(
|
||||
np.array(
|
||||
poly, dtype=np.float32).reshape(-1, 2))
|
||||
txts.append(txt)
|
||||
txt_tags.append(txt == '###')
|
||||
|
||||
return np.array(list(map(np.array, text_polys))), \
|
||||
np.array(txt_tags, dtype=np.bool), txts
|
||||
|
||||
def extract_info_textnet(self, im_fn, img_dir=''):
|
||||
"""
|
||||
Extract information from line in textnet format.
|
||||
"""
|
||||
info_list = im_fn.split('\t')
|
||||
img_path = ''
|
||||
for ext in [
|
||||
'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'JPG'
|
||||
]:
|
||||
if os.path.exists(os.path.join(img_dir, info_list[0] + "." + ext)):
|
||||
img_path = os.path.join(img_dir, info_list[0] + "." + ext)
|
||||
break
|
||||
|
||||
if img_path == '':
|
||||
print('Image {0} NOT found in {1}, and it will be ignored.'.format(
|
||||
info_list[0], img_dir))
|
||||
|
||||
nBox = (len(info_list) - 1) // 9
|
||||
wordBBs, txts, txt_tags = [], [], []
|
||||
for n in range(0, nBox):
|
||||
wordBB = list(map(float, info_list[n * 9 + 1:(n + 1) * 9]))
|
||||
txt = info_list[(n + 1) * 9]
|
||||
wordBBs.append([[wordBB[0], wordBB[1]], [wordBB[2], wordBB[3]],
|
||||
[wordBB[4], wordBB[5]], [wordBB[6], wordBB[7]]])
|
||||
txts.append(txt)
|
||||
if txt == '###':
|
||||
txt_tags.append(True)
|
||||
else:
|
||||
txt_tags.append(False)
|
||||
return img_path, np.array(wordBBs, dtype=np.float32), txt_tags, txts
|
||||
|
||||
def get_image_info_list(self, file_list, ratio_list, data_format='textnet'):
|
||||
if isinstance(file_list, str):
|
||||
file_list = [file_list]
|
||||
data_lines = []
|
||||
for idx, data_source in enumerate(file_list):
|
||||
image_files = []
|
||||
if data_format == 'icdar':
|
||||
image_files = [(data_source, x) for x in
|
||||
os.listdir(os.path.join(data_source, 'rgb'))
|
||||
if x.split('.')[-1] in [
|
||||
'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif',
|
||||
'tiff', 'gif', 'JPG'
|
||||
]]
|
||||
elif data_format == 'textnet':
|
||||
with open(data_source) as f:
|
||||
image_files = [(data_source, x.strip())
|
||||
for x in f.readlines()]
|
||||
else:
|
||||
print("Unrecognized data format...")
|
||||
exit(-1)
|
||||
random.seed(self.seed)
|
||||
image_files = random.sample(
|
||||
image_files, round(len(image_files) * ratio_list[idx]))
|
||||
data_lines.extend(image_files)
|
||||
return data_lines
|
||||
|
||||
def __getitem__(self, idx):
|
||||
file_idx = self.data_idx_order_list[idx]
|
||||
data_path, data_line = self.data_lines[file_idx]
|
||||
try:
|
||||
if self.data_format == 'icdar':
|
||||
im_path = os.path.join(data_path, 'rgb', data_line)
|
||||
if self.mode.lower() == "eval":
|
||||
poly_path = os.path.join(data_path, 'poly_gt',
|
||||
data_line.split('.')[0] + '.txt')
|
||||
else:
|
||||
poly_path = os.path.join(data_path, 'poly',
|
||||
data_line.split('.')[0] + '.txt')
|
||||
text_polys, text_tags, text_strs = self.extract_polys(poly_path)
|
||||
else:
|
||||
image_dir = os.path.join(os.path.dirname(data_path), 'image')
|
||||
im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
|
||||
data_line, image_dir)
|
||||
|
||||
data = {
|
||||
'img_path': im_path,
|
||||
'polys': text_polys,
|
||||
'tags': text_tags,
|
||||
'strs': text_strs
|
||||
}
|
||||
with open(data['img_path'], 'rb') as f:
|
||||
img = f.read()
|
||||
data['image'] = img
|
||||
outs = transform(data, self.ops)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(
|
||||
"When parsing line {}, error happened with msg: {}".format(
|
||||
self.data_idx_order_list[idx], e))
|
||||
outs = None
|
||||
if outs is None:
|
||||
return self.__getitem__(np.random.randint(self.__len__()))
|
||||
return outs
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_idx_order_list)
|
|
@ -23,6 +23,7 @@ class SimpleDataSet(Dataset):
|
|||
def __init__(self, config, mode, logger, seed=None):
|
||||
super(SimpleDataSet, self).__init__()
|
||||
self.logger = logger
|
||||
self.mode = mode.lower()
|
||||
|
||||
global_config = config['Global']
|
||||
dataset_config = config[mode]['dataset']
|
||||
|
@ -45,7 +46,7 @@ class SimpleDataSet(Dataset):
|
|||
logger.info("Initialize indexs of datasets:%s" % label_file_list)
|
||||
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
|
||||
self.data_idx_order_list = list(range(len(self.data_lines)))
|
||||
if mode.lower() == "train":
|
||||
if self.mode == "train" and self.do_shuffle:
|
||||
self.shuffle_data_random()
|
||||
self.ops = create_operators(dataset_config['transforms'], global_config)
|
||||
|
||||
|
@ -56,16 +57,16 @@ class SimpleDataSet(Dataset):
|
|||
for idx, file in enumerate(file_list):
|
||||
with open(file, "rb") as f:
|
||||
lines = f.readlines()
|
||||
random.seed(self.seed)
|
||||
lines = random.sample(lines,
|
||||
round(len(lines) * ratio_list[idx]))
|
||||
if self.mode == "train" or ratio_list[idx] < 1.0:
|
||||
random.seed(self.seed)
|
||||
lines = random.sample(lines,
|
||||
round(len(lines) * ratio_list[idx]))
|
||||
data_lines.extend(lines)
|
||||
return data_lines
|
||||
|
||||
def shuffle_data_random(self):
|
||||
if self.do_shuffle:
|
||||
random.seed(self.seed)
|
||||
random.shuffle(self.data_lines)
|
||||
random.seed(self.seed)
|
||||
random.shuffle(self.data_lines)
|
||||
return
|
||||
|
||||
def __getitem__(self, idx):
|
||||
|
@ -90,7 +91,10 @@ class SimpleDataSet(Dataset):
|
|||
data_line, e))
|
||||
outs = None
|
||||
if outs is None:
|
||||
return self.__getitem__(np.random.randint(self.__len__()))
|
||||
# during evaluation, we should fix the idx to get same results for many times of evaluation.
|
||||
rnd_idx = np.random.randint(self.__len__(
|
||||
)) if self.mode == "train" else (idx + 1) % self.__len__()
|
||||
return self.__getitem__(rnd_idx)
|
||||
return outs
|
||||
|
||||
def __len__(self):
|
||||
|
|
|
@ -29,10 +29,11 @@ def build_loss(config):
|
|||
# cls loss
|
||||
from .cls_loss import ClsLoss
|
||||
|
||||
# e2e loss
|
||||
from .e2e_pg_loss import PGLoss
|
||||
support_dict = [
|
||||
'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
|
||||
'SRNLoss'
|
||||
]
|
||||
'SRNLoss', 'PGLoss']
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -200,6 +200,6 @@ def ohem_batch(scores, gt_texts, training_masks, ohem_ratio):
|
|||
i, :, :], ohem_ratio))
|
||||
|
||||
selected_masks = np.concatenate(selected_masks, 0)
|
||||
selected_masks = paddle.to_variable(selected_masks)
|
||||
selected_masks = paddle.to_tensor(selected_masks)
|
||||
|
||||
return selected_masks
|
||||
|
|
|
@ -0,0 +1,140 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from paddle import nn
|
||||
import paddle
|
||||
|
||||
from .det_basic_loss import DiceLoss
|
||||
from ppocr.utils.e2e_utils.extract_batchsize import pre_process
|
||||
|
||||
|
||||
class PGLoss(nn.Layer):
|
||||
def __init__(self,
|
||||
tcl_bs,
|
||||
max_text_length,
|
||||
max_text_nums,
|
||||
pad_num,
|
||||
eps=1e-6,
|
||||
**kwargs):
|
||||
super(PGLoss, self).__init__()
|
||||
self.tcl_bs = tcl_bs
|
||||
self.max_text_nums = max_text_nums
|
||||
self.max_text_length = max_text_length
|
||||
self.pad_num = pad_num
|
||||
self.dice_loss = DiceLoss(eps=eps)
|
||||
|
||||
def border_loss(self, f_border, l_border, l_score, l_mask):
|
||||
l_border_split, l_border_norm = paddle.tensor.split(
|
||||
l_border, num_or_sections=[4, 1], axis=1)
|
||||
f_border_split = f_border
|
||||
b, c, h, w = l_border_norm.shape
|
||||
l_border_norm_split = paddle.expand(
|
||||
x=l_border_norm, shape=[b, 4 * c, h, w])
|
||||
b, c, h, w = l_score.shape
|
||||
l_border_score = paddle.expand(x=l_score, shape=[b, 4 * c, h, w])
|
||||
b, c, h, w = l_mask.shape
|
||||
l_border_mask = paddle.expand(x=l_mask, shape=[b, 4 * c, h, w])
|
||||
border_diff = l_border_split - f_border_split
|
||||
abs_border_diff = paddle.abs(border_diff)
|
||||
border_sign = abs_border_diff < 1.0
|
||||
border_sign = paddle.cast(border_sign, dtype='float32')
|
||||
border_sign.stop_gradient = True
|
||||
border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
|
||||
(abs_border_diff - 0.5) * (1.0 - border_sign)
|
||||
border_out_loss = l_border_norm_split * border_in_loss
|
||||
border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
|
||||
(paddle.sum(l_border_score * l_border_mask) + 1e-5)
|
||||
return border_loss
|
||||
|
||||
def direction_loss(self, f_direction, l_direction, l_score, l_mask):
|
||||
l_direction_split, l_direction_norm = paddle.tensor.split(
|
||||
l_direction, num_or_sections=[2, 1], axis=1)
|
||||
f_direction_split = f_direction
|
||||
b, c, h, w = l_direction_norm.shape
|
||||
l_direction_norm_split = paddle.expand(
|
||||
x=l_direction_norm, shape=[b, 2 * c, h, w])
|
||||
b, c, h, w = l_score.shape
|
||||
l_direction_score = paddle.expand(x=l_score, shape=[b, 2 * c, h, w])
|
||||
b, c, h, w = l_mask.shape
|
||||
l_direction_mask = paddle.expand(x=l_mask, shape=[b, 2 * c, h, w])
|
||||
direction_diff = l_direction_split - f_direction_split
|
||||
abs_direction_diff = paddle.abs(direction_diff)
|
||||
direction_sign = abs_direction_diff < 1.0
|
||||
direction_sign = paddle.cast(direction_sign, dtype='float32')
|
||||
direction_sign.stop_gradient = True
|
||||
direction_in_loss = 0.5 * abs_direction_diff * abs_direction_diff * direction_sign + \
|
||||
(abs_direction_diff - 0.5) * (1.0 - direction_sign)
|
||||
direction_out_loss = l_direction_norm_split * direction_in_loss
|
||||
direction_loss = paddle.sum(direction_out_loss * l_direction_score * l_direction_mask) / \
|
||||
(paddle.sum(l_direction_score * l_direction_mask) + 1e-5)
|
||||
return direction_loss
|
||||
|
||||
def ctcloss(self, f_char, tcl_pos, tcl_mask, tcl_label, label_t):
|
||||
f_char = paddle.transpose(f_char, [0, 2, 3, 1])
|
||||
tcl_pos = paddle.reshape(tcl_pos, [-1, 3])
|
||||
tcl_pos = paddle.cast(tcl_pos, dtype=int)
|
||||
f_tcl_char = paddle.gather_nd(f_char, tcl_pos)
|
||||
f_tcl_char = paddle.reshape(f_tcl_char,
|
||||
[-1, 64, 37]) # len(Lexicon_Table)+1
|
||||
f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2)
|
||||
f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0
|
||||
b, c, l = tcl_mask.shape
|
||||
tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l])
|
||||
tcl_mask_fg.stop_gradient = True
|
||||
f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * (
|
||||
-20.0)
|
||||
f_tcl_char_mask = paddle.concat([f_tcl_char_fg, f_tcl_char_bg], axis=2)
|
||||
f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2))
|
||||
N, B, _ = f_tcl_char_ld.shape
|
||||
input_lengths = paddle.to_tensor([N] * B, dtype='int64')
|
||||
cost = paddle.nn.functional.ctc_loss(
|
||||
log_probs=f_tcl_char_ld,
|
||||
labels=tcl_label,
|
||||
input_lengths=input_lengths,
|
||||
label_lengths=label_t,
|
||||
blank=self.pad_num,
|
||||
reduction='none')
|
||||
cost = cost.mean()
|
||||
return cost
|
||||
|
||||
def forward(self, predicts, labels):
|
||||
images, tcl_maps, tcl_label_maps, border_maps \
|
||||
, direction_maps, training_masks, label_list, pos_list, pos_mask = labels
|
||||
# for all the batch_size
|
||||
pos_list, pos_mask, label_list, label_t = pre_process(
|
||||
label_list, pos_list, pos_mask, self.max_text_length,
|
||||
self.max_text_nums, self.pad_num, self.tcl_bs)
|
||||
|
||||
f_score, f_border, f_direction, f_char = predicts['f_score'], predicts['f_border'], predicts['f_direction'], \
|
||||
predicts['f_char']
|
||||
score_loss = self.dice_loss(f_score, tcl_maps, training_masks)
|
||||
border_loss = self.border_loss(f_border, border_maps, tcl_maps,
|
||||
training_masks)
|
||||
direction_loss = self.direction_loss(f_direction, direction_maps,
|
||||
tcl_maps, training_masks)
|
||||
ctc_loss = self.ctcloss(f_char, pos_list, pos_mask, label_list, label_t)
|
||||
loss_all = score_loss + border_loss + direction_loss + 5 * ctc_loss
|
||||
|
||||
losses = {
|
||||
'loss': loss_all,
|
||||
"score_loss": score_loss,
|
||||
"border_loss": border_loss,
|
||||
"direction_loss": direction_loss,
|
||||
"ctc_loss": ctc_loss
|
||||
}
|
||||
return losses
|
|
@ -26,8 +26,9 @@ def build_metric(config):
|
|||
from .det_metric import DetMetric
|
||||
from .rec_metric import RecMetric
|
||||
from .cls_metric import ClsMetric
|
||||
from .e2e_metric import E2EMetric
|
||||
|
||||
support_dict = ['DetMetric', 'RecMetric', 'ClsMetric']
|
||||
support_dict = ['DetMetric', 'RecMetric', 'ClsMetric', 'E2EMetric']
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
module_name = config.pop('name')
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
__all__ = ['E2EMetric']
|
||||
|
||||
from ppocr.utils.e2e_metric.Deteval import get_socre, combine_results
|
||||
from ppocr.utils.e2e_utils.extract_textpoint import get_dict
|
||||
|
||||
|
||||
class E2EMetric(object):
|
||||
def __init__(self,
|
||||
character_dict_path,
|
||||
main_indicator='f_score_e2e',
|
||||
**kwargs):
|
||||
self.label_list = get_dict(character_dict_path)
|
||||
self.max_index = len(self.label_list)
|
||||
self.main_indicator = main_indicator
|
||||
self.reset()
|
||||
|
||||
def __call__(self, preds, batch, **kwargs):
|
||||
temp_gt_polyons_batch = batch[2]
|
||||
temp_gt_strs_batch = batch[3]
|
||||
ignore_tags_batch = batch[4]
|
||||
gt_polyons_batch = []
|
||||
gt_strs_batch = []
|
||||
|
||||
temp_gt_polyons_batch = temp_gt_polyons_batch[0].tolist()
|
||||
for temp_list in temp_gt_polyons_batch:
|
||||
t = []
|
||||
for index in temp_list:
|
||||
if index[0] != -1 and index[1] != -1:
|
||||
t.append(index)
|
||||
gt_polyons_batch.append(t)
|
||||
|
||||
temp_gt_strs_batch = temp_gt_strs_batch[0].tolist()
|
||||
for temp_list in temp_gt_strs_batch:
|
||||
t = ""
|
||||
for index in temp_list:
|
||||
if index < self.max_index:
|
||||
t += self.label_list[index]
|
||||
gt_strs_batch.append(t)
|
||||
|
||||
for pred, gt_polyons, gt_strs, ignore_tags in zip(
|
||||
[preds], [gt_polyons_batch], [gt_strs_batch], ignore_tags_batch):
|
||||
# prepare gt
|
||||
gt_info_list = [{
|
||||
'points': gt_polyon,
|
||||
'text': gt_str,
|
||||
'ignore': ignore_tag
|
||||
} for gt_polyon, gt_str, ignore_tag in
|
||||
zip(gt_polyons, gt_strs, ignore_tags)]
|
||||
# prepare det
|
||||
e2e_info_list = [{
|
||||
'points': det_polyon,
|
||||
'text': pred_str
|
||||
} for det_polyon, pred_str in zip(pred['points'], pred['strs'])]
|
||||
result = get_socre(gt_info_list, e2e_info_list)
|
||||
self.results.append(result)
|
||||
|
||||
def get_metric(self):
|
||||
metircs = combine_results(self.results)
|
||||
self.reset()
|
||||
return metircs
|
||||
|
||||
def reset(self):
|
||||
self.results = [] # clear results
|
|
@ -150,7 +150,7 @@ class DetectionIoUEvaluator(object):
|
|||
pairs.append({'gt': gtNum, 'det': detNum})
|
||||
detMatchedNums.append(detNum)
|
||||
evaluationLog += "Match GT #" + \
|
||||
str(gtNum) + " with Det #" + str(detNum) + "\n"
|
||||
str(gtNum) + " with Det #" + str(detNum) + "\n"
|
||||
|
||||
numGtCare = (len(gtPols) - len(gtDontCarePolsNum))
|
||||
numDetCare = (len(detPols) - len(detDontCarePolsNum))
|
||||
|
@ -162,7 +162,7 @@ class DetectionIoUEvaluator(object):
|
|||
precision = 0 if numDetCare == 0 else float(detMatched) / numDetCare
|
||||
|
||||
hmean = 0 if (precision + recall) == 0 else 2.0 * \
|
||||
precision * recall / (precision + recall)
|
||||
precision * recall / (precision + recall)
|
||||
|
||||
matchedSum += detMatched
|
||||
numGlobalCareGt += numGtCare
|
||||
|
@ -200,7 +200,8 @@ class DetectionIoUEvaluator(object):
|
|||
methodPrecision = 0 if numGlobalCareDet == 0 else float(
|
||||
matchedSum) / numGlobalCareDet
|
||||
methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
|
||||
methodRecall * methodPrecision / (methodRecall + methodPrecision)
|
||||
methodRecall * methodPrecision / (
|
||||
methodRecall + methodPrecision)
|
||||
# print(methodRecall, methodPrecision, methodHmean)
|
||||
# sys.exit(-1)
|
||||
methodMetrics = {
|
||||
|
|
|
@ -26,6 +26,9 @@ def build_backbone(config, model_type):
|
|||
from .rec_resnet_vd import ResNet
|
||||
from .rec_resnet_fpn import ResNetFPN
|
||||
support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
|
||||
elif model_type == 'e2e':
|
||||
from .e2e_resnet_vd_pg import ResNet
|
||||
support_dict = ['ResNet']
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -0,0 +1,265 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import ParamAttr
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
__all__ = ["ResNet"]
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
is_vd_mode=False,
|
||||
act=None,
|
||||
name=None, ):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
|
||||
self.is_vd_mode = is_vd_mode
|
||||
self._pool2d_avg = nn.AvgPool2D(
|
||||
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||
self._conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=False)
|
||||
if name == "conv1":
|
||||
bn_name = "bn_" + name
|
||||
else:
|
||||
bn_name = "bn" + name[3:]
|
||||
self._batch_norm = nn.BatchNorm(
|
||||
out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name=bn_name + '_scale'),
|
||||
bias_attr=ParamAttr(bn_name + '_offset'),
|
||||
moving_mean_name=bn_name + '_mean',
|
||||
moving_variance_name=bn_name + '_variance')
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
return y
|
||||
|
||||
|
||||
class BottleneckBlock(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
shortcut=True,
|
||||
if_first=False,
|
||||
name=None):
|
||||
super(BottleneckBlock, self).__init__()
|
||||
|
||||
self.conv0 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
act='relu',
|
||||
name=name + "_branch2a")
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
act='relu',
|
||||
name=name + "_branch2b")
|
||||
self.conv2 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels * 4,
|
||||
kernel_size=1,
|
||||
act=None,
|
||||
name=name + "_branch2c")
|
||||
|
||||
if not shortcut:
|
||||
self.short = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels * 4,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
is_vd_mode=False if if_first else True,
|
||||
name=name + "_branch1")
|
||||
|
||||
self.shortcut = shortcut
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv0(inputs)
|
||||
conv1 = self.conv1(y)
|
||||
conv2 = self.conv2(conv1)
|
||||
|
||||
if self.shortcut:
|
||||
short = inputs
|
||||
else:
|
||||
short = self.short(inputs)
|
||||
y = paddle.add(x=short, y=conv2)
|
||||
y = F.relu(y)
|
||||
return y
|
||||
|
||||
|
||||
class BasicBlock(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
stride,
|
||||
shortcut=True,
|
||||
if_first=False,
|
||||
name=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.stride = stride
|
||||
self.conv0 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
act='relu',
|
||||
name=name + "_branch2a")
|
||||
self.conv1 = ConvBNLayer(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
act=None,
|
||||
name=name + "_branch2b")
|
||||
|
||||
if not shortcut:
|
||||
self.short = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
is_vd_mode=False if if_first else True,
|
||||
name=name + "_branch1")
|
||||
|
||||
self.shortcut = shortcut
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self.conv0(inputs)
|
||||
conv1 = self.conv1(y)
|
||||
|
||||
if self.shortcut:
|
||||
short = inputs
|
||||
else:
|
||||
short = self.short(inputs)
|
||||
y = paddle.add(x=short, y=conv1)
|
||||
y = F.relu(y)
|
||||
return y
|
||||
|
||||
|
||||
class ResNet(nn.Layer):
|
||||
def __init__(self, in_channels=3, layers=50, **kwargs):
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
self.layers = layers
|
||||
supported_layers = [18, 34, 50, 101, 152, 200]
|
||||
assert layers in supported_layers, \
|
||||
"supported layers are {} but input layer is {}".format(
|
||||
supported_layers, layers)
|
||||
|
||||
if layers == 18:
|
||||
depth = [2, 2, 2, 2]
|
||||
elif layers == 34 or layers == 50:
|
||||
# depth = [3, 4, 6, 3]
|
||||
depth = [3, 4, 6, 3, 3]
|
||||
elif layers == 101:
|
||||
depth = [3, 4, 23, 3]
|
||||
elif layers == 152:
|
||||
depth = [3, 8, 36, 3]
|
||||
elif layers == 200:
|
||||
depth = [3, 12, 48, 3]
|
||||
num_channels = [64, 256, 512, 1024,
|
||||
2048] if layers >= 50 else [64, 64, 128, 256]
|
||||
num_filters = [64, 128, 256, 512, 512]
|
||||
|
||||
self.conv1_1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=64,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
act='relu',
|
||||
name="conv1_1")
|
||||
self.pool2d_max = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.stages = []
|
||||
self.out_channels = [3, 64]
|
||||
# num_filters = [64, 128, 256, 512, 512]
|
||||
if layers >= 50:
|
||||
for block in range(len(depth)):
|
||||
block_list = []
|
||||
shortcut = False
|
||||
for i in range(depth[block]):
|
||||
if layers in [101, 152] and block == 2:
|
||||
if i == 0:
|
||||
conv_name = "res" + str(block + 2) + "a"
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + "b" + str(i)
|
||||
else:
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
bottleneck_block = self.add_sublayer(
|
||||
'bb_%d_%d' % (block, i),
|
||||
BottleneckBlock(
|
||||
in_channels=num_channels[block]
|
||||
if i == 0 else num_filters[block] * 4,
|
||||
out_channels=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
if_first=block == i == 0,
|
||||
name=conv_name))
|
||||
shortcut = True
|
||||
block_list.append(bottleneck_block)
|
||||
self.out_channels.append(num_filters[block] * 4)
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
else:
|
||||
for block in range(len(depth)):
|
||||
block_list = []
|
||||
shortcut = False
|
||||
for i in range(depth[block]):
|
||||
conv_name = "res" + str(block + 2) + chr(97 + i)
|
||||
basic_block = self.add_sublayer(
|
||||
'bb_%d_%d' % (block, i),
|
||||
BasicBlock(
|
||||
in_channels=num_channels[block]
|
||||
if i == 0 else num_filters[block],
|
||||
out_channels=num_filters[block],
|
||||
stride=2 if i == 0 and block != 0 else 1,
|
||||
shortcut=shortcut,
|
||||
if_first=block == i == 0,
|
||||
name=conv_name))
|
||||
shortcut = True
|
||||
block_list.append(basic_block)
|
||||
self.out_channels.append(num_filters[block])
|
||||
self.stages.append(nn.Sequential(*block_list))
|
||||
|
||||
def forward(self, inputs):
|
||||
out = [inputs]
|
||||
y = self.conv1_1(inputs)
|
||||
out.append(y)
|
||||
y = self.pool2d_max(y)
|
||||
for block in self.stages:
|
||||
y = block(y)
|
||||
out.append(y)
|
||||
return out
|
|
@ -20,6 +20,7 @@ def build_head(config):
|
|||
from .det_db_head import DBHead
|
||||
from .det_east_head import EASTHead
|
||||
from .det_sast_head import SASTHead
|
||||
from .e2e_pg_head import PGHead
|
||||
|
||||
# rec head
|
||||
from .rec_ctc_head import CTCHead
|
||||
|
@ -30,8 +31,8 @@ def build_head(config):
|
|||
from .cls_head import ClsHead
|
||||
support_dict = [
|
||||
'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
|
||||
'SRNHead'
|
||||
]
|
||||
'SRNHead', 'PGHead']
|
||||
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('head only support {}'.format(
|
||||
|
|
|
@ -0,0 +1,253 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
self.if_act = if_act
|
||||
self.act = act
|
||||
self.conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + '_weights'),
|
||||
bias_attr=False)
|
||||
|
||||
self.bn = nn.BatchNorm(
|
||||
num_channels=out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name="bn_" + name + "_scale"),
|
||||
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
|
||||
moving_mean_name="bn_" + name + "_mean",
|
||||
moving_variance_name="bn_" + name + "_variance",
|
||||
use_global_stats=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
class PGHead(nn.Layer):
|
||||
"""
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super(PGHead, self).__init__()
|
||||
self.conv_f_score1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=64,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_score{}".format(1))
|
||||
self.conv_f_score2 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act='relu',
|
||||
name="conv_f_score{}".format(2))
|
||||
self.conv_f_score3 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=128,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_score{}".format(3))
|
||||
|
||||
self.conv1 = nn.Conv2D(
|
||||
in_channels=128,
|
||||
out_channels=1,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=1,
|
||||
weight_attr=ParamAttr(name="conv_f_score{}".format(4)),
|
||||
bias_attr=False)
|
||||
|
||||
self.conv_f_boder1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=64,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_boder{}".format(1))
|
||||
self.conv_f_boder2 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act='relu',
|
||||
name="conv_f_boder{}".format(2))
|
||||
self.conv_f_boder3 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=128,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_boder{}".format(3))
|
||||
self.conv2 = nn.Conv2D(
|
||||
in_channels=128,
|
||||
out_channels=4,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=1,
|
||||
weight_attr=ParamAttr(name="conv_f_boder{}".format(4)),
|
||||
bias_attr=False)
|
||||
self.conv_f_char1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=128,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_char{}".format(1))
|
||||
self.conv_f_char2 = ConvBNLayer(
|
||||
in_channels=128,
|
||||
out_channels=128,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act='relu',
|
||||
name="conv_f_char{}".format(2))
|
||||
self.conv_f_char3 = ConvBNLayer(
|
||||
in_channels=128,
|
||||
out_channels=256,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_char{}".format(3))
|
||||
self.conv_f_char4 = ConvBNLayer(
|
||||
in_channels=256,
|
||||
out_channels=256,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act='relu',
|
||||
name="conv_f_char{}".format(4))
|
||||
self.conv_f_char5 = ConvBNLayer(
|
||||
in_channels=256,
|
||||
out_channels=256,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_char{}".format(5))
|
||||
self.conv3 = nn.Conv2D(
|
||||
in_channels=256,
|
||||
out_channels=37,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=1,
|
||||
weight_attr=ParamAttr(name="conv_f_char{}".format(6)),
|
||||
bias_attr=False)
|
||||
|
||||
self.conv_f_direc1 = ConvBNLayer(
|
||||
in_channels=in_channels,
|
||||
out_channels=64,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_direc{}".format(1))
|
||||
self.conv_f_direc2 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
act='relu',
|
||||
name="conv_f_direc{}".format(2))
|
||||
self.conv_f_direc3 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=128,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
act='relu',
|
||||
name="conv_f_direc{}".format(3))
|
||||
self.conv4 = nn.Conv2D(
|
||||
in_channels=128,
|
||||
out_channels=2,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=1,
|
||||
weight_attr=ParamAttr(name="conv_f_direc{}".format(4)),
|
||||
bias_attr=False)
|
||||
|
||||
def forward(self, x):
|
||||
f_score = self.conv_f_score1(x)
|
||||
f_score = self.conv_f_score2(f_score)
|
||||
f_score = self.conv_f_score3(f_score)
|
||||
f_score = self.conv1(f_score)
|
||||
f_score = F.sigmoid(f_score)
|
||||
|
||||
# f_border
|
||||
f_border = self.conv_f_boder1(x)
|
||||
f_border = self.conv_f_boder2(f_border)
|
||||
f_border = self.conv_f_boder3(f_border)
|
||||
f_border = self.conv2(f_border)
|
||||
|
||||
f_char = self.conv_f_char1(x)
|
||||
f_char = self.conv_f_char2(f_char)
|
||||
f_char = self.conv_f_char3(f_char)
|
||||
f_char = self.conv_f_char4(f_char)
|
||||
f_char = self.conv_f_char5(f_char)
|
||||
f_char = self.conv3(f_char)
|
||||
|
||||
f_direction = self.conv_f_direc1(x)
|
||||
f_direction = self.conv_f_direc2(f_direction)
|
||||
f_direction = self.conv_f_direc3(f_direction)
|
||||
f_direction = self.conv4(f_direction)
|
||||
|
||||
predicts = {}
|
||||
predicts['f_score'] = f_score
|
||||
predicts['f_border'] = f_border
|
||||
predicts['f_char'] = f_char
|
||||
predicts['f_direction'] = f_direction
|
||||
return predicts
|
|
@ -38,7 +38,7 @@ class AttentionHead(nn.Layer):
|
|||
return input_ont_hot
|
||||
|
||||
def forward(self, inputs, targets=None, batch_max_length=25):
|
||||
batch_size = inputs.shape[0]
|
||||
batch_size = paddle.shape(inputs)[0]
|
||||
num_steps = batch_max_length
|
||||
|
||||
hidden = paddle.zeros((batch_size, self.hidden_size))
|
||||
|
@ -57,6 +57,9 @@ class AttentionHead(nn.Layer):
|
|||
else:
|
||||
targets = paddle.zeros(shape=[batch_size], dtype="int32")
|
||||
probs = None
|
||||
char_onehots = None
|
||||
outputs = None
|
||||
alpha = None
|
||||
|
||||
for i in range(num_steps):
|
||||
char_onehots = self._char_to_onehot(
|
||||
|
|
|
@ -14,12 +14,14 @@
|
|||
|
||||
__all__ = ['build_neck']
|
||||
|
||||
|
||||
def build_neck(config):
|
||||
from .db_fpn import DBFPN
|
||||
from .east_fpn import EASTFPN
|
||||
from .sast_fpn import SASTFPN
|
||||
from .rnn import SequenceEncoder
|
||||
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder']
|
||||
from .pg_fpn import PGFPN
|
||||
support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN']
|
||||
|
||||
module_name = config.pop('name')
|
||||
assert module_name in support_dict, Exception('neck only support {}'.format(
|
||||
|
|
|
@ -0,0 +1,314 @@
|
|||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle import nn
|
||||
import paddle.nn.functional as F
|
||||
from paddle import ParamAttr
|
||||
|
||||
|
||||
class ConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
groups=1,
|
||||
is_vd_mode=False,
|
||||
act=None,
|
||||
name=None):
|
||||
super(ConvBNLayer, self).__init__()
|
||||
|
||||
self.is_vd_mode = is_vd_mode
|
||||
self._pool2d_avg = nn.AvgPool2D(
|
||||
kernel_size=2, stride=2, padding=0, ceil_mode=True)
|
||||
self._conv = nn.Conv2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + "_weights"),
|
||||
bias_attr=False)
|
||||
if name == "conv1":
|
||||
bn_name = "bn_" + name
|
||||
else:
|
||||
bn_name = "bn" + name[3:]
|
||||
self._batch_norm = nn.BatchNorm(
|
||||
out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name=bn_name + '_scale'),
|
||||
bias_attr=ParamAttr(bn_name + '_offset'),
|
||||
moving_mean_name=bn_name + '_mean',
|
||||
moving_variance_name=bn_name + '_variance',
|
||||
use_global_stats=False)
|
||||
|
||||
def forward(self, inputs):
|
||||
y = self._conv(inputs)
|
||||
y = self._batch_norm(y)
|
||||
return y
|
||||
|
||||
|
||||
class DeConvBNLayer(nn.Layer):
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1,
|
||||
groups=1,
|
||||
if_act=True,
|
||||
act=None,
|
||||
name=None):
|
||||
super(DeConvBNLayer, self).__init__()
|
||||
|
||||
self.if_act = if_act
|
||||
self.act = act
|
||||
self.deconv = nn.Conv2DTranspose(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=groups,
|
||||
weight_attr=ParamAttr(name=name + '_weights'),
|
||||
bias_attr=False)
|
||||
self.bn = nn.BatchNorm(
|
||||
num_channels=out_channels,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name="bn_" + name + "_scale"),
|
||||
bias_attr=ParamAttr(name="bn_" + name + "_offset"),
|
||||
moving_mean_name="bn_" + name + "_mean",
|
||||
moving_variance_name="bn_" + name + "_variance",
|
||||
use_global_stats=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.deconv(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
class PGFPN(nn.Layer):
|
||||
def __init__(self, in_channels, **kwargs):
|
||||
super(PGFPN, self).__init__()
|
||||
num_inputs = [2048, 2048, 1024, 512, 256]
|
||||
num_outputs = [256, 256, 192, 192, 128]
|
||||
self.out_channels = 128
|
||||
self.conv_bn_layer_1 = ConvBNLayer(
|
||||
in_channels=3,
|
||||
out_channels=32,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act=None,
|
||||
name='FPN_d1')
|
||||
self.conv_bn_layer_2 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act=None,
|
||||
name='FPN_d2')
|
||||
self.conv_bn_layer_3 = ConvBNLayer(
|
||||
in_channels=256,
|
||||
out_channels=128,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act=None,
|
||||
name='FPN_d3')
|
||||
self.conv_bn_layer_4 = ConvBNLayer(
|
||||
in_channels=32,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
act=None,
|
||||
name='FPN_d4')
|
||||
self.conv_bn_layer_5 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=64,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name='FPN_d5')
|
||||
self.conv_bn_layer_6 = ConvBNLayer(
|
||||
in_channels=64,
|
||||
out_channels=128,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
act=None,
|
||||
name='FPN_d6')
|
||||
self.conv_bn_layer_7 = ConvBNLayer(
|
||||
in_channels=128,
|
||||
out_channels=128,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name='FPN_d7')
|
||||
self.conv_bn_layer_8 = ConvBNLayer(
|
||||
in_channels=128,
|
||||
out_channels=128,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name='FPN_d8')
|
||||
|
||||
self.conv_h0 = ConvBNLayer(
|
||||
in_channels=num_inputs[0],
|
||||
out_channels=num_outputs[0],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="conv_h{}".format(0))
|
||||
self.conv_h1 = ConvBNLayer(
|
||||
in_channels=num_inputs[1],
|
||||
out_channels=num_outputs[1],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="conv_h{}".format(1))
|
||||
self.conv_h2 = ConvBNLayer(
|
||||
in_channels=num_inputs[2],
|
||||
out_channels=num_outputs[2],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="conv_h{}".format(2))
|
||||
self.conv_h3 = ConvBNLayer(
|
||||
in_channels=num_inputs[3],
|
||||
out_channels=num_outputs[3],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="conv_h{}".format(3))
|
||||
self.conv_h4 = ConvBNLayer(
|
||||
in_channels=num_inputs[4],
|
||||
out_channels=num_outputs[4],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="conv_h{}".format(4))
|
||||
|
||||
self.dconv0 = DeConvBNLayer(
|
||||
in_channels=num_outputs[0],
|
||||
out_channels=num_outputs[0 + 1],
|
||||
name="dconv_{}".format(0))
|
||||
self.dconv1 = DeConvBNLayer(
|
||||
in_channels=num_outputs[1],
|
||||
out_channels=num_outputs[1 + 1],
|
||||
act=None,
|
||||
name="dconv_{}".format(1))
|
||||
self.dconv2 = DeConvBNLayer(
|
||||
in_channels=num_outputs[2],
|
||||
out_channels=num_outputs[2 + 1],
|
||||
act=None,
|
||||
name="dconv_{}".format(2))
|
||||
self.dconv3 = DeConvBNLayer(
|
||||
in_channels=num_outputs[3],
|
||||
out_channels=num_outputs[3 + 1],
|
||||
act=None,
|
||||
name="dconv_{}".format(3))
|
||||
self.conv_g1 = ConvBNLayer(
|
||||
in_channels=num_outputs[1],
|
||||
out_channels=num_outputs[1],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name="conv_g{}".format(1))
|
||||
self.conv_g2 = ConvBNLayer(
|
||||
in_channels=num_outputs[2],
|
||||
out_channels=num_outputs[2],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name="conv_g{}".format(2))
|
||||
self.conv_g3 = ConvBNLayer(
|
||||
in_channels=num_outputs[3],
|
||||
out_channels=num_outputs[3],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name="conv_g{}".format(3))
|
||||
self.conv_g4 = ConvBNLayer(
|
||||
in_channels=num_outputs[4],
|
||||
out_channels=num_outputs[4],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
act='relu',
|
||||
name="conv_g{}".format(4))
|
||||
self.convf = ConvBNLayer(
|
||||
in_channels=num_outputs[4],
|
||||
out_channels=num_outputs[4],
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
act=None,
|
||||
name="conv_f{}".format(4))
|
||||
|
||||
def forward(self, x):
|
||||
c0, c1, c2, c3, c4, c5, c6 = x
|
||||
# FPN_Down_Fusion
|
||||
f = [c0, c1, c2]
|
||||
g = [None, None, None]
|
||||
h = [None, None, None]
|
||||
h[0] = self.conv_bn_layer_1(f[0])
|
||||
h[1] = self.conv_bn_layer_2(f[1])
|
||||
h[2] = self.conv_bn_layer_3(f[2])
|
||||
|
||||
g[0] = self.conv_bn_layer_4(h[0])
|
||||
g[1] = paddle.add(g[0], h[1])
|
||||
g[1] = F.relu(g[1])
|
||||
g[1] = self.conv_bn_layer_5(g[1])
|
||||
g[1] = self.conv_bn_layer_6(g[1])
|
||||
|
||||
g[2] = paddle.add(g[1], h[2])
|
||||
g[2] = F.relu(g[2])
|
||||
g[2] = self.conv_bn_layer_7(g[2])
|
||||
f_down = self.conv_bn_layer_8(g[2])
|
||||
|
||||
# FPN UP Fusion
|
||||
f1 = [c6, c5, c4, c3, c2]
|
||||
g = [None, None, None, None, None]
|
||||
h = [None, None, None, None, None]
|
||||
h[0] = self.conv_h0(f1[0])
|
||||
h[1] = self.conv_h1(f1[1])
|
||||
h[2] = self.conv_h2(f1[2])
|
||||
h[3] = self.conv_h3(f1[3])
|
||||
h[4] = self.conv_h4(f1[4])
|
||||
|
||||
g[0] = self.dconv0(h[0])
|
||||
g[1] = paddle.add(g[0], h[1])
|
||||
g[1] = F.relu(g[1])
|
||||
g[1] = self.conv_g1(g[1])
|
||||
g[1] = self.dconv1(g[1])
|
||||
|
||||
g[2] = paddle.add(g[1], h[2])
|
||||
g[2] = F.relu(g[2])
|
||||
g[2] = self.conv_g2(g[2])
|
||||
g[2] = self.dconv2(g[2])
|
||||
|
||||
g[3] = paddle.add(g[2], h[3])
|
||||
g[3] = F.relu(g[3])
|
||||
g[3] = self.conv_g3(g[3])
|
||||
g[3] = self.dconv3(g[3])
|
||||
|
||||
g[4] = paddle.add(x=g[3], y=h[4])
|
||||
g[4] = F.relu(g[4])
|
||||
g[4] = self.conv_g4(g[4])
|
||||
f_up = self.convf(g[4])
|
||||
f_common = paddle.add(f_down, f_up)
|
||||
f_common = F.relu(f_common)
|
||||
return f_common
|
|
@ -28,10 +28,11 @@ def build_post_process(config, global_config=None):
|
|||
from .sast_postprocess import SASTPostProcess
|
||||
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
|
||||
support_dict = [
|
||||
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode'
|
||||
'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess'
|
||||
]
|
||||
|
||||
config = copy.deepcopy(config)
|
||||
|
|
|
@ -0,0 +1,155 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..'))
|
||||
|
||||
from ppocr.utils.e2e_utils.extract_textpoint import *
|
||||
from ppocr.utils.e2e_utils.visual import *
|
||||
import paddle
|
||||
|
||||
|
||||
class PGPostProcess(object):
|
||||
"""
|
||||
The post process for PGNet.
|
||||
"""
|
||||
|
||||
def __init__(self, character_dict_path, valid_set, score_thresh, **kwargs):
|
||||
self.Lexicon_Table = get_dict(character_dict_path)
|
||||
self.valid_set = valid_set
|
||||
self.score_thresh = score_thresh
|
||||
|
||||
# c++ la-nms is faster, but only support python 3.5
|
||||
self.is_python35 = False
|
||||
if sys.version_info.major == 3 and sys.version_info.minor == 5:
|
||||
self.is_python35 = True
|
||||
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
p_score = outs_dict['f_score']
|
||||
p_border = outs_dict['f_border']
|
||||
p_char = outs_dict['f_char']
|
||||
p_direction = outs_dict['f_direction']
|
||||
if isinstance(p_score, paddle.Tensor):
|
||||
p_score = p_score[0].numpy()
|
||||
p_border = p_border[0].numpy()
|
||||
p_direction = p_direction[0].numpy()
|
||||
p_char = p_char[0].numpy()
|
||||
else:
|
||||
p_score = p_score[0]
|
||||
p_border = p_border[0]
|
||||
p_direction = p_direction[0]
|
||||
p_char = p_char[0]
|
||||
src_h, src_w, ratio_h, ratio_w = shape_list[0]
|
||||
is_curved = self.valid_set == "totaltext"
|
||||
instance_yxs_list = generate_pivot_list(
|
||||
p_score,
|
||||
p_char,
|
||||
p_direction,
|
||||
score_thresh=self.score_thresh,
|
||||
is_backbone=True,
|
||||
is_curved=is_curved)
|
||||
p_char = paddle.to_tensor(np.expand_dims(p_char, axis=0))
|
||||
char_seq_idx_set = []
|
||||
for i in range(len(instance_yxs_list)):
|
||||
gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
|
||||
f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
|
||||
feature_seq = paddle.gather_nd(f_char_map, gather_info_lod)
|
||||
feature_seq = np.expand_dims(feature_seq.numpy(), axis=0)
|
||||
feature_len = [len(feature_seq[0])]
|
||||
featyre_seq = paddle.to_tensor(feature_seq)
|
||||
feature_len = np.array([feature_len]).astype(np.int64)
|
||||
length = paddle.to_tensor(feature_len)
|
||||
seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
|
||||
input=featyre_seq, blank=36, input_length=length)
|
||||
seq_pred_str = seq_pred[0].numpy().tolist()[0]
|
||||
seq_len = seq_pred[1].numpy()[0][0]
|
||||
temp_t = []
|
||||
for c in seq_pred_str[:seq_len]:
|
||||
temp_t.append(c)
|
||||
char_seq_idx_set.append(temp_t)
|
||||
seq_strs = []
|
||||
for char_idx_set in char_seq_idx_set:
|
||||
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
|
||||
seq_strs.append(pr_str)
|
||||
poly_list = []
|
||||
keep_str_list = []
|
||||
all_point_list = []
|
||||
all_point_pair_list = []
|
||||
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
|
||||
if len(yx_center_line) == 1:
|
||||
yx_center_line.append(yx_center_line[-1])
|
||||
|
||||
offset_expand = 1.0
|
||||
if self.valid_set == 'totaltext':
|
||||
offset_expand = 1.2
|
||||
|
||||
point_pair_list = []
|
||||
for batch_id, y, x in yx_center_line:
|
||||
offset = p_border[:, y, x].reshape(2, 2)
|
||||
if offset_expand != 1.0:
|
||||
offset_length = np.linalg.norm(
|
||||
offset, axis=1, keepdims=True)
|
||||
expand_length = np.clip(
|
||||
offset_length * (offset_expand - 1),
|
||||
a_min=0.5,
|
||||
a_max=3.0)
|
||||
offset_detal = offset / offset_length * expand_length
|
||||
offset = offset + offset_detal
|
||||
ori_yx = np.array([y, x], dtype=np.float32)
|
||||
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
|
||||
[ratio_w, ratio_h]).reshape(-1, 2)
|
||||
point_pair_list.append(point_pair)
|
||||
|
||||
all_point_list.append([
|
||||
int(round(x * 4.0 / ratio_w)),
|
||||
int(round(y * 4.0 / ratio_h))
|
||||
])
|
||||
all_point_pair_list.append(point_pair.round().astype(np.int32)
|
||||
.tolist())
|
||||
|
||||
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
|
||||
detected_poly = expand_poly_along_width(
|
||||
detected_poly, shrink_ratio_of_width=0.2)
|
||||
detected_poly[:, 0] = np.clip(
|
||||
detected_poly[:, 0], a_min=0, a_max=src_w)
|
||||
detected_poly[:, 1] = np.clip(
|
||||
detected_poly[:, 1], a_min=0, a_max=src_h)
|
||||
|
||||
if len(keep_str) < 2:
|
||||
continue
|
||||
|
||||
keep_str_list.append(keep_str)
|
||||
if self.valid_set == 'partvgg':
|
||||
middle_point = len(detected_poly) // 2
|
||||
detected_poly = detected_poly[
|
||||
[0, middle_point - 1, middle_point, -1], :]
|
||||
poly_list.append(detected_poly)
|
||||
elif self.valid_set == 'totaltext':
|
||||
poly_list.append(detected_poly)
|
||||
else:
|
||||
print('--> Not supported format.')
|
||||
exit(-1)
|
||||
data = {
|
||||
'points': poly_list,
|
||||
'strs': keep_str_list,
|
||||
}
|
||||
return data
|
|
@ -18,6 +18,7 @@ from __future__ import print_function
|
|||
|
||||
import os
|
||||
import sys
|
||||
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '..'))
|
||||
|
@ -49,12 +50,12 @@ class SASTPostProcess(object):
|
|||
self.shrink_ratio_of_width = shrink_ratio_of_width
|
||||
self.expand_scale = expand_scale
|
||||
self.tcl_map_thresh = tcl_map_thresh
|
||||
|
||||
|
||||
# c++ la-nms is faster, but only support python 3.5
|
||||
self.is_python35 = False
|
||||
if sys.version_info.major == 3 and sys.version_info.minor == 5:
|
||||
self.is_python35 = True
|
||||
|
||||
|
||||
def point_pair2poly(self, point_pair_list):
|
||||
"""
|
||||
Transfer vertical point_pairs into poly point in clockwise.
|
||||
|
@ -66,31 +67,42 @@ class SASTPostProcess(object):
|
|||
point_list[idx] = point_pair[0]
|
||||
point_list[point_num - 1 - idx] = point_pair[1]
|
||||
return np.array(point_list).reshape(-1, 2)
|
||||
|
||||
def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
|
||||
|
||||
def shrink_quad_along_width(self,
|
||||
quad,
|
||||
begin_width_ratio=0.,
|
||||
end_width_ratio=1.):
|
||||
"""
|
||||
Generate shrink_quad_along_width.
|
||||
"""
|
||||
ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||
ratio_pair = np.array(
|
||||
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
||||
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
||||
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
||||
|
||||
|
||||
def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3):
|
||||
"""
|
||||
expand poly along width.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
left_quad = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
|
||||
left_quad = np.array(
|
||||
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
|
||||
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
|
||||
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
|
||||
left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, 1.0)
|
||||
right_quad = np.array([poly[point_num // 2 - 2], poly[point_num // 2 - 1],
|
||||
poly[point_num // 2], poly[point_num // 2 + 1]], dtype=np.float32)
|
||||
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
|
||||
left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio,
|
||||
1.0)
|
||||
right_quad = np.array(
|
||||
[
|
||||
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
|
||||
poly[point_num // 2], poly[point_num // 2 + 1]
|
||||
],
|
||||
dtype=np.float32)
|
||||
right_ratio = 1.0 + \
|
||||
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
|
||||
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
|
||||
right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, right_ratio)
|
||||
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
|
||||
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
|
||||
right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0,
|
||||
right_ratio)
|
||||
poly[0] = left_quad_expand[0]
|
||||
poly[-1] = left_quad_expand[-1]
|
||||
poly[point_num // 2 - 1] = right_quad_expand[1]
|
||||
|
@ -100,7 +112,7 @@ class SASTPostProcess(object):
|
|||
def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map):
|
||||
"""Restore quad."""
|
||||
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
|
||||
xy_text = xy_text[:, ::-1] # (n, 2)
|
||||
xy_text = xy_text[:, ::-1] # (n, 2)
|
||||
|
||||
# Sort the text boxes via the y axis
|
||||
xy_text = xy_text[np.argsort(xy_text[:, 1])]
|
||||
|
@ -112,7 +124,7 @@ class SASTPostProcess(object):
|
|||
point_num = int(tvo_map.shape[-1] / 2)
|
||||
assert point_num == 4
|
||||
tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :]
|
||||
xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
|
||||
xy_text_tile = np.tile(xy_text, (1, point_num)) # (n, point_num * 2)
|
||||
quads = xy_text_tile - tvo_map
|
||||
|
||||
return scores, quads, xy_text
|
||||
|
@ -121,14 +133,12 @@ class SASTPostProcess(object):
|
|||
"""
|
||||
compute area of a quad.
|
||||
"""
|
||||
edge = [
|
||||
(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
|
||||
(quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
|
||||
(quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
|
||||
(quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])
|
||||
]
|
||||
edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]),
|
||||
(quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]),
|
||||
(quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]),
|
||||
(quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])]
|
||||
return np.sum(edge) / 2.
|
||||
|
||||
|
||||
def nms(self, dets):
|
||||
if self.is_python35:
|
||||
import lanms
|
||||
|
@ -141,7 +151,7 @@ class SASTPostProcess(object):
|
|||
"""
|
||||
Cluster pixels in tcl_map based on quads.
|
||||
"""
|
||||
instance_count = quads.shape[0] + 1 # contain background
|
||||
instance_count = quads.shape[0] + 1 # contain background
|
||||
instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32)
|
||||
if instance_count == 1:
|
||||
return instance_count, instance_label_map
|
||||
|
@ -149,18 +159,19 @@ class SASTPostProcess(object):
|
|||
# predict text center
|
||||
xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh)
|
||||
n = xy_text.shape[0]
|
||||
xy_text = xy_text[:, ::-1] # (n, 2)
|
||||
tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
|
||||
xy_text = xy_text[:, ::-1] # (n, 2)
|
||||
tco = tco_map[xy_text[:, 1], xy_text[:, 0], :] # (n, 2)
|
||||
pred_tc = xy_text - tco
|
||||
|
||||
|
||||
# get gt text center
|
||||
m = quads.shape[0]
|
||||
gt_tc = np.mean(quads, axis=1) # (m, 2)
|
||||
gt_tc = np.mean(quads, axis=1) # (m, 2)
|
||||
|
||||
pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], (1, m, 1)) # (n, m, 2)
|
||||
gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
|
||||
dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
|
||||
xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
|
||||
pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :],
|
||||
(1, m, 1)) # (n, m, 2)
|
||||
gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1)) # (n, m, 2)
|
||||
dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2) # (n, m)
|
||||
xy_text_assign = np.argmin(dist_mat, axis=1) + 1 # (n,)
|
||||
|
||||
instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign
|
||||
return instance_count, instance_label_map
|
||||
|
@ -169,26 +180,47 @@ class SASTPostProcess(object):
|
|||
"""
|
||||
Estimate sample points number.
|
||||
"""
|
||||
eh = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])) / 2.0
|
||||
ew = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
|
||||
eh = (np.linalg.norm(quad[0] - quad[3]) +
|
||||
np.linalg.norm(quad[1] - quad[2])) / 2.0
|
||||
ew = (np.linalg.norm(quad[0] - quad[1]) +
|
||||
np.linalg.norm(quad[2] - quad[3])) / 2.0
|
||||
|
||||
dense_sample_pts_num = max(2, int(ew))
|
||||
dense_xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, dense_sample_pts_num,
|
||||
endpoint=True, dtype=np.float32).astype(np.int32)]
|
||||
dense_xy_center_line = xy_text[np.linspace(
|
||||
0,
|
||||
xy_text.shape[0] - 1,
|
||||
dense_sample_pts_num,
|
||||
endpoint=True,
|
||||
dtype=np.float32).astype(np.int32)]
|
||||
|
||||
dense_xy_center_line_diff = dense_xy_center_line[1:] - dense_xy_center_line[:-1]
|
||||
estimate_arc_len = np.sum(np.linalg.norm(dense_xy_center_line_diff, axis=1))
|
||||
dense_xy_center_line_diff = dense_xy_center_line[
|
||||
1:] - dense_xy_center_line[:-1]
|
||||
estimate_arc_len = np.sum(
|
||||
np.linalg.norm(
|
||||
dense_xy_center_line_diff, axis=1))
|
||||
|
||||
sample_pts_num = max(2, int(estimate_arc_len / eh))
|
||||
return sample_pts_num
|
||||
|
||||
def detect_sast(self, tcl_map, tvo_map, tbo_map, tco_map, ratio_w, ratio_h, src_w, src_h,
|
||||
shrink_ratio_of_width=0.3, tcl_map_thresh=0.5, offset_expand=1.0, out_strid=4.0):
|
||||
def detect_sast(self,
|
||||
tcl_map,
|
||||
tvo_map,
|
||||
tbo_map,
|
||||
tco_map,
|
||||
ratio_w,
|
||||
ratio_h,
|
||||
src_w,
|
||||
src_h,
|
||||
shrink_ratio_of_width=0.3,
|
||||
tcl_map_thresh=0.5,
|
||||
offset_expand=1.0,
|
||||
out_strid=4.0):
|
||||
"""
|
||||
first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys
|
||||
"""
|
||||
# restore quad
|
||||
scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, tvo_map)
|
||||
scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh,
|
||||
tvo_map)
|
||||
dets = np.hstack((quads, scores)).astype(np.float32, copy=False)
|
||||
dets = self.nms(dets)
|
||||
if dets.shape[0] == 0:
|
||||
|
@ -202,7 +234,8 @@ class SASTPostProcess(object):
|
|||
|
||||
# instance segmentation
|
||||
# instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8)
|
||||
instance_count, instance_label_map = self.cluster_by_quads_tco(tcl_map, tcl_map_thresh, quads, tco_map)
|
||||
instance_count, instance_label_map = self.cluster_by_quads_tco(
|
||||
tcl_map, tcl_map_thresh, quads, tco_map)
|
||||
|
||||
# restore single poly with tcl instance.
|
||||
poly_list = []
|
||||
|
@ -212,10 +245,10 @@ class SASTPostProcess(object):
|
|||
q_area = quad_areas[instance_idx - 1]
|
||||
if q_area < 5:
|
||||
continue
|
||||
|
||||
|
||||
#
|
||||
len1 = float(np.linalg.norm(quad[0] -quad[1]))
|
||||
len2 = float(np.linalg.norm(quad[1] -quad[2]))
|
||||
len1 = float(np.linalg.norm(quad[0] - quad[1]))
|
||||
len2 = float(np.linalg.norm(quad[1] - quad[2]))
|
||||
min_len = min(len1, len2)
|
||||
if min_len < 3:
|
||||
continue
|
||||
|
@ -225,16 +258,18 @@ class SASTPostProcess(object):
|
|||
continue
|
||||
|
||||
# filter low confidence instance
|
||||
xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
|
||||
xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0]
|
||||
if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1:
|
||||
# if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
|
||||
# if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05:
|
||||
continue
|
||||
|
||||
# sort xy_text
|
||||
left_center_pt = np.array([[(quad[0, 0] + quad[-1, 0]) / 2.0,
|
||||
(quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
|
||||
right_center_pt = np.array([[(quad[1, 0] + quad[2, 0]) / 2.0,
|
||||
(quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
|
||||
left_center_pt = np.array(
|
||||
[[(quad[0, 0] + quad[-1, 0]) / 2.0,
|
||||
(quad[0, 1] + quad[-1, 1]) / 2.0]]) # (1, 2)
|
||||
right_center_pt = np.array(
|
||||
[[(quad[1, 0] + quad[2, 0]) / 2.0,
|
||||
(quad[1, 1] + quad[2, 1]) / 2.0]]) # (1, 2)
|
||||
proj_unit_vec = (right_center_pt - left_center_pt) / \
|
||||
(np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
|
||||
proj_value = np.sum(xy_text * proj_unit_vec, axis=1)
|
||||
|
@ -245,33 +280,45 @@ class SASTPostProcess(object):
|
|||
sample_pts_num = self.estimate_sample_pts_num(quad, xy_text)
|
||||
else:
|
||||
sample_pts_num = self.sample_pts_num
|
||||
xy_center_line = xy_text[np.linspace(0, xy_text.shape[0] - 1, sample_pts_num,
|
||||
endpoint=True, dtype=np.float32).astype(np.int32)]
|
||||
xy_center_line = xy_text[np.linspace(
|
||||
0,
|
||||
xy_text.shape[0] - 1,
|
||||
sample_pts_num,
|
||||
endpoint=True,
|
||||
dtype=np.float32).astype(np.int32)]
|
||||
|
||||
point_pair_list = []
|
||||
for x, y in xy_center_line:
|
||||
# get corresponding offset
|
||||
offset = tbo_map[y, x, :].reshape(2, 2)
|
||||
if offset_expand != 1.0:
|
||||
offset_length = np.linalg.norm(offset, axis=1, keepdims=True)
|
||||
expand_length = np.clip(offset_length * (offset_expand - 1), a_min=0.5, a_max=3.0)
|
||||
offset_length = np.linalg.norm(
|
||||
offset, axis=1, keepdims=True)
|
||||
expand_length = np.clip(
|
||||
offset_length * (offset_expand - 1),
|
||||
a_min=0.5,
|
||||
a_max=3.0)
|
||||
offset_detal = offset / offset_length * expand_length
|
||||
offset = offset + offset_detal
|
||||
# original point
|
||||
offset = offset + offset_detal
|
||||
# original point
|
||||
ori_yx = np.array([y, x], dtype=np.float32)
|
||||
point_pair = (ori_yx + offset)[:, ::-1]* out_strid / np.array([ratio_w, ratio_h]).reshape(-1, 2)
|
||||
point_pair = (ori_yx + offset)[:, ::-1] * out_strid / np.array(
|
||||
[ratio_w, ratio_h]).reshape(-1, 2)
|
||||
point_pair_list.append(point_pair)
|
||||
|
||||
# ndarry: (x, 2), expand poly along width
|
||||
detected_poly = self.point_pair2poly(point_pair_list)
|
||||
detected_poly = self.expand_poly_along_width(detected_poly, shrink_ratio_of_width)
|
||||
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
|
||||
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
|
||||
detected_poly = self.expand_poly_along_width(detected_poly,
|
||||
shrink_ratio_of_width)
|
||||
detected_poly[:, 0] = np.clip(
|
||||
detected_poly[:, 0], a_min=0, a_max=src_w)
|
||||
detected_poly[:, 1] = np.clip(
|
||||
detected_poly[:, 1], a_min=0, a_max=src_h)
|
||||
poly_list.append(detected_poly)
|
||||
|
||||
return poly_list
|
||||
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
def __call__(self, outs_dict, shape_list):
|
||||
score_list = outs_dict['f_score']
|
||||
border_list = outs_dict['f_border']
|
||||
tvo_list = outs_dict['f_tvo']
|
||||
|
@ -281,20 +328,28 @@ class SASTPostProcess(object):
|
|||
border_list = border_list.numpy()
|
||||
tvo_list = tvo_list.numpy()
|
||||
tco_list = tco_list.numpy()
|
||||
|
||||
|
||||
img_num = len(shape_list)
|
||||
poly_lists = []
|
||||
for ino in range(img_num):
|
||||
p_score = score_list[ino].transpose((1,2,0))
|
||||
p_border = border_list[ino].transpose((1,2,0))
|
||||
p_tvo = tvo_list[ino].transpose((1,2,0))
|
||||
p_tco = tco_list[ino].transpose((1,2,0))
|
||||
p_score = score_list[ino].transpose((1, 2, 0))
|
||||
p_border = border_list[ino].transpose((1, 2, 0))
|
||||
p_tvo = tvo_list[ino].transpose((1, 2, 0))
|
||||
p_tco = tco_list[ino].transpose((1, 2, 0))
|
||||
src_h, src_w, ratio_h, ratio_w = shape_list[ino]
|
||||
|
||||
poly_list = self.detect_sast(p_score, p_tvo, p_border, p_tco, ratio_w, ratio_h, src_w, src_h,
|
||||
shrink_ratio_of_width=self.shrink_ratio_of_width,
|
||||
tcl_map_thresh=self.tcl_map_thresh, offset_expand=self.expand_scale)
|
||||
poly_list = self.detect_sast(
|
||||
p_score,
|
||||
p_tvo,
|
||||
p_border,
|
||||
p_tco,
|
||||
ratio_w,
|
||||
ratio_h,
|
||||
src_w,
|
||||
src_h,
|
||||
shrink_ratio_of_width=self.shrink_ratio_of_width,
|
||||
tcl_map_thresh=self.tcl_map_thresh,
|
||||
offset_expand=self.expand_scale)
|
||||
poly_lists.append({'points': np.array(poly_list)})
|
||||
|
||||
return poly_lists
|
||||
|
||||
|
|
|
@ -0,0 +1,458 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
|
||||
|
||||
|
||||
def get_socre(gt_dict, pred_dict):
|
||||
allInputs = 1
|
||||
|
||||
def input_reading_mod(pred_dict):
|
||||
"""This helper reads input from txt files"""
|
||||
det = []
|
||||
n = len(pred_dict)
|
||||
for i in range(n):
|
||||
points = pred_dict[i]['points']
|
||||
text = pred_dict[i]['text']
|
||||
point = ",".join(map(str, points.reshape(-1, )))
|
||||
det.append([point, text])
|
||||
return det
|
||||
|
||||
def gt_reading_mod(gt_dict):
|
||||
"""This helper reads groundtruths from mat files"""
|
||||
gt = []
|
||||
n = len(gt_dict)
|
||||
for i in range(n):
|
||||
points = gt_dict[i]['points']
|
||||
h = len(points)
|
||||
text = gt_dict[i]['text']
|
||||
xx = [
|
||||
np.array(
|
||||
['x:'], dtype='<U2'), 0, np.array(
|
||||
['y:'], dtype='<U2'), 0, np.array(
|
||||
['#'], dtype='<U1'), np.array(
|
||||
['#'], dtype='<U1')
|
||||
]
|
||||
t_x, t_y = [], []
|
||||
for j in range(h):
|
||||
t_x.append(points[j][0])
|
||||
t_y.append(points[j][1])
|
||||
xx[1] = np.array([t_x], dtype='int16')
|
||||
xx[3] = np.array([t_y], dtype='int16')
|
||||
if text != "" and "#" not in text:
|
||||
xx[4] = np.array([text], dtype='U{}'.format(len(text)))
|
||||
xx[5] = np.array(['c'], dtype='<U1')
|
||||
gt.append(xx)
|
||||
return gt
|
||||
|
||||
def detection_filtering(detections, groundtruths, threshold=0.5):
|
||||
for gt_id, gt in enumerate(groundtruths):
|
||||
if (gt[5] == '#') and (gt[1].shape[1] > 1):
|
||||
gt_x = list(map(int, np.squeeze(gt[1])))
|
||||
gt_y = list(map(int, np.squeeze(gt[3])))
|
||||
for det_id, detection in enumerate(detections):
|
||||
detection_orig = detection
|
||||
detection = [float(x) for x in detection[0].split(',')]
|
||||
detection = list(map(int, detection))
|
||||
det_x = detection[0::2]
|
||||
det_y = detection[1::2]
|
||||
det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
|
||||
if det_gt_iou > threshold:
|
||||
detections[det_id] = []
|
||||
|
||||
detections[:] = [item for item in detections if item != []]
|
||||
return detections
|
||||
|
||||
def sigma_calculation(det_x, det_y, gt_x, gt_y):
|
||||
"""
|
||||
sigma = inter_area / gt_area
|
||||
"""
|
||||
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
|
||||
area(gt_x, gt_y)), 2)
|
||||
|
||||
def tau_calculation(det_x, det_y, gt_x, gt_y):
|
||||
if area(det_x, det_y) == 0.0:
|
||||
return 0
|
||||
return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
|
||||
area(det_x, det_y)), 2)
|
||||
|
||||
##############################Initialization###################################
|
||||
# global_sigma = []
|
||||
# global_tau = []
|
||||
# global_pred_str = []
|
||||
# global_gt_str = []
|
||||
###############################################################################
|
||||
|
||||
for input_id in range(allInputs):
|
||||
if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
|
||||
input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
|
||||
input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
|
||||
and (input_id != 'Deteval_result_non_curved.txt'):
|
||||
detections = input_reading_mod(pred_dict)
|
||||
groundtruths = gt_reading_mod(gt_dict)
|
||||
detections = detection_filtering(
|
||||
detections,
|
||||
groundtruths) # filters detections overlapping with DC area
|
||||
dc_id = []
|
||||
for i in range(len(groundtruths)):
|
||||
if groundtruths[i][5] == '#':
|
||||
dc_id.append(i)
|
||||
cnt = 0
|
||||
for a in dc_id:
|
||||
num = a - cnt
|
||||
del groundtruths[num]
|
||||
cnt += 1
|
||||
|
||||
local_sigma_table = np.zeros((len(groundtruths), len(detections)))
|
||||
local_tau_table = np.zeros((len(groundtruths), len(detections)))
|
||||
local_pred_str = {}
|
||||
local_gt_str = {}
|
||||
|
||||
for gt_id, gt in enumerate(groundtruths):
|
||||
if len(detections) > 0:
|
||||
for det_id, detection in enumerate(detections):
|
||||
detection_orig = detection
|
||||
detection = [float(x) for x in detection[0].split(',')]
|
||||
detection = list(map(int, detection))
|
||||
pred_seq_str = detection_orig[1].strip()
|
||||
det_x = detection[0::2]
|
||||
det_y = detection[1::2]
|
||||
gt_x = list(map(int, np.squeeze(gt[1])))
|
||||
gt_y = list(map(int, np.squeeze(gt[3])))
|
||||
gt_seq_str = str(gt[4].tolist()[0])
|
||||
|
||||
local_sigma_table[gt_id, det_id] = sigma_calculation(
|
||||
det_x, det_y, gt_x, gt_y)
|
||||
local_tau_table[gt_id, det_id] = tau_calculation(
|
||||
det_x, det_y, gt_x, gt_y)
|
||||
local_pred_str[det_id] = pred_seq_str
|
||||
local_gt_str[gt_id] = gt_seq_str
|
||||
|
||||
global_sigma = local_sigma_table
|
||||
global_tau = local_tau_table
|
||||
global_pred_str = local_pred_str
|
||||
global_gt_str = local_gt_str
|
||||
|
||||
single_data = {}
|
||||
single_data['sigma'] = global_sigma
|
||||
single_data['global_tau'] = global_tau
|
||||
single_data['global_pred_str'] = global_pred_str
|
||||
single_data['global_gt_str'] = global_gt_str
|
||||
return single_data
|
||||
|
||||
|
||||
def combine_results(all_data):
|
||||
tr = 0.7
|
||||
tp = 0.6
|
||||
fsc_k = 0.8
|
||||
k = 2
|
||||
global_sigma = []
|
||||
global_tau = []
|
||||
global_pred_str = []
|
||||
global_gt_str = []
|
||||
for data in all_data:
|
||||
global_sigma.append(data['sigma'])
|
||||
global_tau.append(data['global_tau'])
|
||||
global_pred_str.append(data['global_pred_str'])
|
||||
global_gt_str.append(data['global_gt_str'])
|
||||
|
||||
global_accumulative_recall = 0
|
||||
global_accumulative_precision = 0
|
||||
total_num_gt = 0
|
||||
total_num_det = 0
|
||||
hit_str_count = 0
|
||||
hit_count = 0
|
||||
|
||||
def one_to_one(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idy):
|
||||
hit_str_num = 0
|
||||
for gt_id in range(num_gt):
|
||||
gt_matching_qualified_sigma_candidates = np.where(
|
||||
local_sigma_table[gt_id, :] > tr)
|
||||
gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
|
||||
0].shape[0]
|
||||
gt_matching_qualified_tau_candidates = np.where(
|
||||
local_tau_table[gt_id, :] > tp)
|
||||
gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
|
||||
0].shape[0]
|
||||
|
||||
det_matching_qualified_sigma_candidates = np.where(
|
||||
local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
|
||||
> tr)
|
||||
det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
|
||||
0].shape[0]
|
||||
det_matching_qualified_tau_candidates = np.where(
|
||||
local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
|
||||
tp)
|
||||
det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
|
||||
0].shape[0]
|
||||
|
||||
if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
|
||||
(det_matching_num_qualified_sigma_candidates == 1) and (
|
||||
det_matching_num_qualified_tau_candidates == 1):
|
||||
global_accumulative_recall = global_accumulative_recall + 1.0
|
||||
global_accumulative_precision = global_accumulative_precision + 1.0
|
||||
local_accumulative_recall = local_accumulative_recall + 1.0
|
||||
local_accumulative_precision = local_accumulative_precision + 1.0
|
||||
|
||||
gt_flag[0, gt_id] = 1
|
||||
matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
|
||||
# recg start
|
||||
gt_str_cur = global_gt_str[idy][gt_id]
|
||||
pred_str_cur = global_pred_str[idy][matched_det_id[0].tolist()[
|
||||
0]]
|
||||
if pred_str_cur == gt_str_cur:
|
||||
hit_str_num += 1
|
||||
else:
|
||||
if pred_str_cur.lower() == gt_str_cur.lower():
|
||||
hit_str_num += 1
|
||||
# recg end
|
||||
det_flag[0, matched_det_id] = 1
|
||||
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
|
||||
|
||||
def one_to_many(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idy):
|
||||
hit_str_num = 0
|
||||
for gt_id in range(num_gt):
|
||||
# skip the following if the groundtruth was matched
|
||||
if gt_flag[0, gt_id] > 0:
|
||||
continue
|
||||
|
||||
non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
|
||||
num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]
|
||||
|
||||
if num_non_zero_in_sigma >= k:
|
||||
####search for all detections that overlaps with this groundtruth
|
||||
qualified_tau_candidates = np.where((local_tau_table[
|
||||
gt_id, :] >= tp) & (det_flag[0, :] == 0))
|
||||
num_qualified_tau_candidates = qualified_tau_candidates[
|
||||
0].shape[0]
|
||||
|
||||
if num_qualified_tau_candidates == 1:
|
||||
if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
|
||||
and
|
||||
(local_sigma_table[gt_id, qualified_tau_candidates] >=
|
||||
tr)):
|
||||
# became an one-to-one case
|
||||
global_accumulative_recall = global_accumulative_recall + 1.0
|
||||
global_accumulative_precision = global_accumulative_precision + 1.0
|
||||
local_accumulative_recall = local_accumulative_recall + 1.0
|
||||
local_accumulative_precision = local_accumulative_precision + 1.0
|
||||
|
||||
gt_flag[0, gt_id] = 1
|
||||
det_flag[0, qualified_tau_candidates] = 1
|
||||
# recg start
|
||||
gt_str_cur = global_gt_str[idy][gt_id]
|
||||
pred_str_cur = global_pred_str[idy][
|
||||
qualified_tau_candidates[0].tolist()[0]]
|
||||
if pred_str_cur == gt_str_cur:
|
||||
hit_str_num += 1
|
||||
else:
|
||||
if pred_str_cur.lower() == gt_str_cur.lower():
|
||||
hit_str_num += 1
|
||||
# recg end
|
||||
elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
|
||||
>= tr):
|
||||
gt_flag[0, gt_id] = 1
|
||||
det_flag[0, qualified_tau_candidates] = 1
|
||||
# recg start
|
||||
gt_str_cur = global_gt_str[idy][gt_id]
|
||||
pred_str_cur = global_pred_str[idy][
|
||||
qualified_tau_candidates[0].tolist()[0]]
|
||||
if pred_str_cur == gt_str_cur:
|
||||
hit_str_num += 1
|
||||
else:
|
||||
if pred_str_cur.lower() == gt_str_cur.lower():
|
||||
hit_str_num += 1
|
||||
# recg end
|
||||
|
||||
global_accumulative_recall = global_accumulative_recall + fsc_k
|
||||
global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k
|
||||
|
||||
local_accumulative_recall = local_accumulative_recall + fsc_k
|
||||
local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k
|
||||
|
||||
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
|
||||
|
||||
def many_to_one(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idy):
|
||||
hit_str_num = 0
|
||||
for det_id in range(num_det):
|
||||
# skip the following if the detection was matched
|
||||
if det_flag[0, det_id] > 0:
|
||||
continue
|
||||
|
||||
non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
|
||||
num_non_zero_in_tau = non_zero_in_tau[0].shape[0]
|
||||
|
||||
if num_non_zero_in_tau >= k:
|
||||
####search for all detections that overlaps with this groundtruth
|
||||
qualified_sigma_candidates = np.where((
|
||||
local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
|
||||
num_qualified_sigma_candidates = qualified_sigma_candidates[
|
||||
0].shape[0]
|
||||
|
||||
if num_qualified_sigma_candidates == 1:
|
||||
if ((local_tau_table[qualified_sigma_candidates, det_id] >=
|
||||
tp) and
|
||||
(local_sigma_table[qualified_sigma_candidates, det_id]
|
||||
>= tr)):
|
||||
# became an one-to-one case
|
||||
global_accumulative_recall = global_accumulative_recall + 1.0
|
||||
global_accumulative_precision = global_accumulative_precision + 1.0
|
||||
local_accumulative_recall = local_accumulative_recall + 1.0
|
||||
local_accumulative_precision = local_accumulative_precision + 1.0
|
||||
|
||||
gt_flag[0, qualified_sigma_candidates] = 1
|
||||
det_flag[0, det_id] = 1
|
||||
# recg start
|
||||
pred_str_cur = global_pred_str[idy][det_id]
|
||||
gt_len = len(qualified_sigma_candidates[0])
|
||||
for idx in range(gt_len):
|
||||
ele_gt_id = qualified_sigma_candidates[0].tolist()[
|
||||
idx]
|
||||
if ele_gt_id not in global_gt_str[idy]:
|
||||
continue
|
||||
gt_str_cur = global_gt_str[idy][ele_gt_id]
|
||||
if pred_str_cur == gt_str_cur:
|
||||
hit_str_num += 1
|
||||
break
|
||||
else:
|
||||
if pred_str_cur.lower() == gt_str_cur.lower():
|
||||
hit_str_num += 1
|
||||
break
|
||||
# recg end
|
||||
elif (np.sum(local_tau_table[qualified_sigma_candidates,
|
||||
det_id]) >= tp):
|
||||
det_flag[0, det_id] = 1
|
||||
gt_flag[0, qualified_sigma_candidates] = 1
|
||||
# recg start
|
||||
pred_str_cur = global_pred_str[idy][det_id]
|
||||
gt_len = len(qualified_sigma_candidates[0])
|
||||
for idx in range(gt_len):
|
||||
ele_gt_id = qualified_sigma_candidates[0].tolist()[idx]
|
||||
if ele_gt_id not in global_gt_str[idy]:
|
||||
continue
|
||||
gt_str_cur = global_gt_str[idy][ele_gt_id]
|
||||
if pred_str_cur == gt_str_cur:
|
||||
hit_str_num += 1
|
||||
break
|
||||
else:
|
||||
if pred_str_cur.lower() == gt_str_cur.lower():
|
||||
hit_str_num += 1
|
||||
break
|
||||
# recg end
|
||||
|
||||
global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
|
||||
global_accumulative_precision = global_accumulative_precision + fsc_k
|
||||
|
||||
local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
|
||||
local_accumulative_precision = local_accumulative_precision + fsc_k
|
||||
return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num
|
||||
|
||||
for idx in range(len(global_sigma)):
|
||||
local_sigma_table = np.array(global_sigma[idx])
|
||||
local_tau_table = global_tau[idx]
|
||||
|
||||
num_gt = local_sigma_table.shape[0]
|
||||
num_det = local_sigma_table.shape[1]
|
||||
|
||||
total_num_gt = total_num_gt + num_gt
|
||||
total_num_det = total_num_det + num_det
|
||||
|
||||
local_accumulative_recall = 0
|
||||
local_accumulative_precision = 0
|
||||
gt_flag = np.zeros((1, num_gt))
|
||||
det_flag = np.zeros((1, num_det))
|
||||
|
||||
#######first check for one-to-one case##########
|
||||
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
|
||||
gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idx)
|
||||
|
||||
hit_str_count += hit_str_num
|
||||
#######then check for one-to-many case##########
|
||||
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
|
||||
gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idx)
|
||||
hit_str_count += hit_str_num
|
||||
#######then check for many-to-one case##########
|
||||
local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
|
||||
gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
|
||||
local_accumulative_recall, local_accumulative_precision,
|
||||
global_accumulative_recall, global_accumulative_precision,
|
||||
gt_flag, det_flag, idx)
|
||||
hit_str_count += hit_str_num
|
||||
|
||||
try:
|
||||
recall = global_accumulative_recall / total_num_gt
|
||||
except ZeroDivisionError:
|
||||
recall = 0
|
||||
|
||||
try:
|
||||
precision = global_accumulative_precision / total_num_det
|
||||
except ZeroDivisionError:
|
||||
precision = 0
|
||||
|
||||
try:
|
||||
f_score = 2 * precision * recall / (precision + recall)
|
||||
except ZeroDivisionError:
|
||||
f_score = 0
|
||||
|
||||
try:
|
||||
seqerr = 1 - float(hit_str_count) / global_accumulative_recall
|
||||
except ZeroDivisionError:
|
||||
seqerr = 1
|
||||
|
||||
try:
|
||||
recall_e2e = float(hit_str_count) / total_num_gt
|
||||
except ZeroDivisionError:
|
||||
recall_e2e = 0
|
||||
|
||||
try:
|
||||
precision_e2e = float(hit_str_count) / total_num_det
|
||||
except ZeroDivisionError:
|
||||
precision_e2e = 0
|
||||
|
||||
try:
|
||||
f_score_e2e = 2 * precision_e2e * recall_e2e / (
|
||||
precision_e2e + recall_e2e)
|
||||
except ZeroDivisionError:
|
||||
f_score_e2e = 0
|
||||
|
||||
final = {
|
||||
'total_num_gt': total_num_gt,
|
||||
'total_num_det': total_num_det,
|
||||
'global_accumulative_recall': global_accumulative_recall,
|
||||
'hit_str_count': hit_str_count,
|
||||
'recall': recall,
|
||||
'precision': precision,
|
||||
'f_score': f_score,
|
||||
'seqerr': seqerr,
|
||||
'recall_e2e': recall_e2e,
|
||||
'precision_e2e': precision_e2e,
|
||||
'f_score_e2e': f_score_e2e
|
||||
}
|
||||
return final
|
|
@ -0,0 +1,83 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
from shapely.geometry import Polygon
|
||||
"""
|
||||
:param det_x: [1, N] Xs of detection's vertices
|
||||
:param det_y: [1, N] Ys of detection's vertices
|
||||
:param gt_x: [1, N] Xs of groundtruth's vertices
|
||||
:param gt_y: [1, N] Ys of groundtruth's vertices
|
||||
|
||||
##############
|
||||
All the calculation of 'AREA' in this script is handled by:
|
||||
1) First generating a binary mask with the polygon area filled up with 1's
|
||||
2) Summing up all the 1's
|
||||
"""
|
||||
|
||||
|
||||
def area(x, y):
|
||||
polygon = Polygon(np.stack([x, y], axis=1))
|
||||
return float(polygon.area)
|
||||
|
||||
|
||||
def approx_area_of_intersection(det_x, det_y, gt_x, gt_y):
|
||||
"""
|
||||
This helper determine if both polygons are intersecting with each others with an approximation method.
|
||||
Area of intersection represented by the minimum bounding rectangular [xmin, ymin, xmax, ymax]
|
||||
"""
|
||||
det_ymax = np.max(det_y)
|
||||
det_xmax = np.max(det_x)
|
||||
det_ymin = np.min(det_y)
|
||||
det_xmin = np.min(det_x)
|
||||
|
||||
gt_ymax = np.max(gt_y)
|
||||
gt_xmax = np.max(gt_x)
|
||||
gt_ymin = np.min(gt_y)
|
||||
gt_xmin = np.min(gt_x)
|
||||
|
||||
all_min_ymax = np.minimum(det_ymax, gt_ymax)
|
||||
all_max_ymin = np.maximum(det_ymin, gt_ymin)
|
||||
|
||||
intersect_heights = np.maximum(0.0, (all_min_ymax - all_max_ymin))
|
||||
|
||||
all_min_xmax = np.minimum(det_xmax, gt_xmax)
|
||||
all_max_xmin = np.maximum(det_xmin, gt_xmin)
|
||||
intersect_widths = np.maximum(0.0, (all_min_xmax - all_max_xmin))
|
||||
|
||||
return intersect_heights * intersect_widths
|
||||
|
||||
|
||||
def area_of_intersection(det_x, det_y, gt_x, gt_y):
|
||||
p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
|
||||
p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
|
||||
return float(p1.intersection(p2).area)
|
||||
|
||||
|
||||
def area_of_union(det_x, det_y, gt_x, gt_y):
|
||||
p1 = Polygon(np.stack([det_x, det_y], axis=1)).buffer(0)
|
||||
p2 = Polygon(np.stack([gt_x, gt_y], axis=1)).buffer(0)
|
||||
return float(p1.union(p2).area)
|
||||
|
||||
|
||||
def iou(det_x, det_y, gt_x, gt_y):
|
||||
return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
|
||||
area_of_union(det_x, det_y, gt_x, gt_y) + 1.0)
|
||||
|
||||
|
||||
def iod(det_x, det_y, gt_x, gt_y):
|
||||
"""
|
||||
This helper determine the fraction of intersection area over detection area
|
||||
"""
|
||||
return area_of_intersection(det_x, det_y, gt_x, gt_y) / (
|
||||
area(det_x, det_y) + 1.0)
|
|
@ -0,0 +1,87 @@
|
|||
import paddle
|
||||
import numpy as np
|
||||
import copy
|
||||
|
||||
|
||||
def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs):
|
||||
"""
|
||||
"""
|
||||
pos_lists_, pos_masks_, label_lists_ = [], [], []
|
||||
img_bs = batch_size
|
||||
ngpu = int(batch_size / img_bs)
|
||||
img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()
|
||||
pos_lists_split, pos_masks_split, label_lists_split = [], [], []
|
||||
for i in range(ngpu):
|
||||
pos_lists_split.append([])
|
||||
pos_masks_split.append([])
|
||||
label_lists_split.append([])
|
||||
|
||||
for i in range(img_ids.shape[0]):
|
||||
img_id = img_ids[i]
|
||||
gpu_id = int(img_id / img_bs)
|
||||
img_id = img_id % img_bs
|
||||
pos_list = pos_lists[i].copy()
|
||||
pos_list[:, 0] = img_id
|
||||
pos_lists_split[gpu_id].append(pos_list)
|
||||
pos_masks_split[gpu_id].append(pos_masks[i].copy())
|
||||
label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i]))
|
||||
# repeat or delete
|
||||
for i in range(ngpu):
|
||||
vp_len = len(pos_lists_split[i])
|
||||
if vp_len <= tcl_bs:
|
||||
for j in range(0, tcl_bs - vp_len):
|
||||
pos_list = pos_lists_split[i][j].copy()
|
||||
pos_lists_split[i].append(pos_list)
|
||||
pos_mask = pos_masks_split[i][j].copy()
|
||||
pos_masks_split[i].append(pos_mask)
|
||||
label_list = copy.deepcopy(label_lists_split[i][j])
|
||||
label_lists_split[i].append(label_list)
|
||||
else:
|
||||
for j in range(0, vp_len - tcl_bs):
|
||||
c_len = len(pos_lists_split[i])
|
||||
pop_id = np.random.permutation(c_len)[0]
|
||||
pos_lists_split[i].pop(pop_id)
|
||||
pos_masks_split[i].pop(pop_id)
|
||||
label_lists_split[i].pop(pop_id)
|
||||
# merge
|
||||
for i in range(ngpu):
|
||||
pos_lists_.extend(pos_lists_split[i])
|
||||
pos_masks_.extend(pos_masks_split[i])
|
||||
label_lists_.extend(label_lists_split[i])
|
||||
return pos_lists_, pos_masks_, label_lists_
|
||||
|
||||
|
||||
def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums,
|
||||
pad_num, tcl_bs):
|
||||
label_list = label_list.numpy()
|
||||
batch, _, _, _ = label_list.shape
|
||||
pos_list = pos_list.numpy()
|
||||
pos_mask = pos_mask.numpy()
|
||||
pos_list_t = []
|
||||
pos_mask_t = []
|
||||
label_list_t = []
|
||||
for i in range(batch):
|
||||
for j in range(max_text_nums):
|
||||
if pos_mask[i, j].any():
|
||||
pos_list_t.append(pos_list[i][j])
|
||||
pos_mask_t.append(pos_mask[i][j])
|
||||
label_list_t.append(label_list[i][j])
|
||||
pos_list, pos_mask, label_list = org_tcl_rois(batch, pos_list_t, pos_mask_t,
|
||||
label_list_t, tcl_bs)
|
||||
label = []
|
||||
tt = [l.tolist() for l in label_list]
|
||||
for i in range(tcl_bs):
|
||||
k = 0
|
||||
for j in range(max_text_length):
|
||||
if tt[i][j][0] != pad_num:
|
||||
k += 1
|
||||
else:
|
||||
break
|
||||
label.append(k)
|
||||
label = paddle.to_tensor(label)
|
||||
label = paddle.cast(label, dtype='int64')
|
||||
pos_list = paddle.to_tensor(pos_list)
|
||||
pos_mask = paddle.to_tensor(pos_mask)
|
||||
label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2)
|
||||
label_list = paddle.cast(label_list, dtype='int32')
|
||||
return pos_list, pos_mask, label_list, label
|
|
@ -0,0 +1,532 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Contains various CTC decoders."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import cv2
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
from itertools import groupby
|
||||
from skimage.morphology._skeletonize import thin
|
||||
|
||||
|
||||
def get_dict(character_dict_path):
|
||||
character_str = ""
|
||||
with open(character_dict_path, "rb") as fin:
|
||||
lines = fin.readlines()
|
||||
for line in lines:
|
||||
line = line.decode('utf-8').strip("\n").strip("\r\n")
|
||||
character_str += line
|
||||
dict_character = list(character_str)
|
||||
return dict_character
|
||||
|
||||
|
||||
def softmax(logits):
|
||||
"""
|
||||
logits: N x d
|
||||
"""
|
||||
max_value = np.max(logits, axis=1, keepdims=True)
|
||||
exp = np.exp(logits - max_value)
|
||||
exp_sum = np.sum(exp, axis=1, keepdims=True)
|
||||
dist = exp / exp_sum
|
||||
return dist
|
||||
|
||||
|
||||
def get_keep_pos_idxs(labels, remove_blank=None):
|
||||
"""
|
||||
Remove duplicate and get pos idxs of keep items.
|
||||
The value of keep_blank should be [None, 95].
|
||||
"""
|
||||
duplicate_len_list = []
|
||||
keep_pos_idx_list = []
|
||||
keep_char_idx_list = []
|
||||
for k, v_ in groupby(labels):
|
||||
current_len = len(list(v_))
|
||||
if k != remove_blank:
|
||||
current_idx = int(sum(duplicate_len_list) + current_len // 2)
|
||||
keep_pos_idx_list.append(current_idx)
|
||||
keep_char_idx_list.append(k)
|
||||
duplicate_len_list.append(current_len)
|
||||
return keep_char_idx_list, keep_pos_idx_list
|
||||
|
||||
|
||||
def remove_blank(labels, blank=0):
|
||||
new_labels = [x for x in labels if x != blank]
|
||||
return new_labels
|
||||
|
||||
|
||||
def insert_blank(labels, blank=0):
|
||||
new_labels = [blank]
|
||||
for l in labels:
|
||||
new_labels += [l, blank]
|
||||
return new_labels
|
||||
|
||||
|
||||
def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
|
||||
"""
|
||||
CTC greedy (best path) decoder.
|
||||
"""
|
||||
raw_str = np.argmax(np.array(probs_seq), axis=1)
|
||||
remove_blank_in_pos = None if keep_blank_in_idxs else blank
|
||||
dedup_str, keep_idx_list = get_keep_pos_idxs(
|
||||
raw_str, remove_blank=remove_blank_in_pos)
|
||||
dst_str = remove_blank(dedup_str, blank=blank)
|
||||
return dst_str, keep_idx_list
|
||||
|
||||
|
||||
def instance_ctc_greedy_decoder(gather_info,
|
||||
logits_map,
|
||||
keep_blank_in_idxs=True):
|
||||
"""
|
||||
gather_info: [[x, y], [x, y] ...]
|
||||
logits_map: H x W X (n_chars + 1)
|
||||
"""
|
||||
_, _, C = logits_map.shape
|
||||
ys, xs = zip(*gather_info)
|
||||
logits_seq = logits_map[list(ys), list(xs)] # n x 96
|
||||
probs_seq = softmax(logits_seq)
|
||||
dst_str, keep_idx_list = ctc_greedy_decoder(
|
||||
probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
|
||||
keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
|
||||
return dst_str, keep_gather_list
|
||||
|
||||
|
||||
def ctc_decoder_for_image(gather_info_list, logits_map,
|
||||
keep_blank_in_idxs=True):
|
||||
"""
|
||||
CTC decoder using multiple processes.
|
||||
"""
|
||||
decoder_results = []
|
||||
for gather_info in gather_info_list:
|
||||
res = instance_ctc_greedy_decoder(
|
||||
gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
|
||||
decoder_results.append(res)
|
||||
return decoder_results
|
||||
|
||||
|
||||
def sort_with_direction(pos_list, f_direction):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
pos_list: [[y, x], [y, x], [y, x] ...]
|
||||
"""
|
||||
|
||||
def sort_part_with_direction(pos_list, point_direction):
|
||||
pos_list = np.array(pos_list).reshape(-1, 2)
|
||||
point_direction = np.array(point_direction).reshape(-1, 2)
|
||||
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
||||
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
|
||||
sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
|
||||
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
|
||||
return sorted_list, sorted_direction
|
||||
|
||||
pos_list = np.array(pos_list).reshape(-1, 2)
|
||||
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
|
||||
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
||||
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
|
||||
point_direction)
|
||||
|
||||
point_num = len(sorted_point)
|
||||
if point_num >= 16:
|
||||
middle_num = point_num // 2
|
||||
first_part_point = sorted_point[:middle_num]
|
||||
first_point_direction = sorted_direction[:middle_num]
|
||||
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
|
||||
first_part_point, first_point_direction)
|
||||
|
||||
last_part_point = sorted_point[middle_num:]
|
||||
last_point_direction = sorted_direction[middle_num:]
|
||||
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
|
||||
last_part_point, last_point_direction)
|
||||
sorted_point = sorted_fist_part_point + sorted_last_part_point
|
||||
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
|
||||
|
||||
return sorted_point, np.array(sorted_direction)
|
||||
|
||||
|
||||
def add_id(pos_list, image_id=0):
|
||||
"""
|
||||
Add id for gather feature, for inference.
|
||||
"""
|
||||
new_list = []
|
||||
for item in pos_list:
|
||||
new_list.append((image_id, item[0], item[1]))
|
||||
return new_list
|
||||
|
||||
|
||||
def sort_and_expand_with_direction(pos_list, f_direction):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
pos_list: [[y, x], [y, x], [y, x] ...]
|
||||
"""
|
||||
h, w, _ = f_direction.shape
|
||||
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
|
||||
|
||||
# expand along
|
||||
point_num = len(sorted_list)
|
||||
sub_direction_len = max(point_num // 3, 2)
|
||||
left_direction = point_direction[:sub_direction_len, :]
|
||||
right_dirction = point_direction[point_num - sub_direction_len:, :]
|
||||
|
||||
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
|
||||
left_average_len = np.linalg.norm(left_average_direction)
|
||||
left_start = np.array(sorted_list[0])
|
||||
left_step = left_average_direction / (left_average_len + 1e-6)
|
||||
|
||||
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
|
||||
right_average_len = np.linalg.norm(right_average_direction)
|
||||
right_step = right_average_direction / (right_average_len + 1e-6)
|
||||
right_start = np.array(sorted_list[-1])
|
||||
|
||||
append_num = max(
|
||||
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
|
||||
left_list = []
|
||||
right_list = []
|
||||
for i in range(append_num):
|
||||
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
|
||||
'int32').tolist()
|
||||
if ly < h and lx < w and (ly, lx) not in left_list:
|
||||
left_list.append((ly, lx))
|
||||
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
|
||||
'int32').tolist()
|
||||
if ry < h and rx < w and (ry, rx) not in right_list:
|
||||
right_list.append((ry, rx))
|
||||
|
||||
all_list = left_list[::-1] + sorted_list + right_list
|
||||
return all_list
|
||||
|
||||
|
||||
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
pos_list: [[y, x], [y, x], [y, x] ...]
|
||||
binary_tcl_map: h x w
|
||||
"""
|
||||
h, w, _ = f_direction.shape
|
||||
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
|
||||
|
||||
# expand along
|
||||
point_num = len(sorted_list)
|
||||
sub_direction_len = max(point_num // 3, 2)
|
||||
left_direction = point_direction[:sub_direction_len, :]
|
||||
right_dirction = point_direction[point_num - sub_direction_len:, :]
|
||||
|
||||
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
|
||||
left_average_len = np.linalg.norm(left_average_direction)
|
||||
left_start = np.array(sorted_list[0])
|
||||
left_step = left_average_direction / (left_average_len + 1e-6)
|
||||
|
||||
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
|
||||
right_average_len = np.linalg.norm(right_average_direction)
|
||||
right_step = right_average_direction / (right_average_len + 1e-6)
|
||||
right_start = np.array(sorted_list[-1])
|
||||
|
||||
append_num = max(
|
||||
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
|
||||
max_append_num = 2 * append_num
|
||||
|
||||
left_list = []
|
||||
right_list = []
|
||||
for i in range(max_append_num):
|
||||
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
|
||||
'int32').tolist()
|
||||
if ly < h and lx < w and (ly, lx) not in left_list:
|
||||
if binary_tcl_map[ly, lx] > 0.5:
|
||||
left_list.append((ly, lx))
|
||||
else:
|
||||
break
|
||||
|
||||
for i in range(max_append_num):
|
||||
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
|
||||
'int32').tolist()
|
||||
if ry < h and rx < w and (ry, rx) not in right_list:
|
||||
if binary_tcl_map[ry, rx] > 0.5:
|
||||
right_list.append((ry, rx))
|
||||
else:
|
||||
break
|
||||
|
||||
all_list = left_list[::-1] + sorted_list + right_list
|
||||
return all_list
|
||||
|
||||
|
||||
def generate_pivot_list_curved(p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
score_thresh=0.5,
|
||||
is_expand=True,
|
||||
is_backbone=False,
|
||||
image_id=0):
|
||||
"""
|
||||
return center point and end point of TCL instance; filter with the char maps;
|
||||
"""
|
||||
p_score = p_score[0]
|
||||
f_direction = f_direction.transpose(1, 2, 0)
|
||||
p_tcl_map = (p_score > score_thresh) * 1.0
|
||||
skeleton_map = thin(p_tcl_map)
|
||||
instance_count, instance_label_map = cv2.connectedComponents(
|
||||
skeleton_map.astype(np.uint8), connectivity=8)
|
||||
|
||||
# get TCL Instance
|
||||
all_pos_yxs = []
|
||||
center_pos_yxs = []
|
||||
end_points_yxs = []
|
||||
instance_center_pos_yxs = []
|
||||
if instance_count > 0:
|
||||
for instance_id in range(1, instance_count):
|
||||
pos_list = []
|
||||
ys, xs = np.where(instance_label_map == instance_id)
|
||||
pos_list = list(zip(ys, xs))
|
||||
|
||||
### FIX-ME, eliminate outlier
|
||||
if len(pos_list) < 3:
|
||||
continue
|
||||
|
||||
if is_expand:
|
||||
pos_list_sorted = sort_and_expand_with_direction_v2(
|
||||
pos_list, f_direction, p_tcl_map)
|
||||
else:
|
||||
pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
|
||||
all_pos_yxs.append(pos_list_sorted)
|
||||
|
||||
# use decoder to filter backgroud points.
|
||||
p_char_maps = p_char_maps.transpose([1, 2, 0])
|
||||
decode_res = ctc_decoder_for_image(
|
||||
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
|
||||
for decoded_str, keep_yxs_list in decode_res:
|
||||
if is_backbone:
|
||||
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
|
||||
instance_center_pos_yxs.append(keep_yxs_list_with_id)
|
||||
else:
|
||||
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
|
||||
center_pos_yxs.extend(keep_yxs_list)
|
||||
|
||||
if is_backbone:
|
||||
return instance_center_pos_yxs
|
||||
else:
|
||||
return center_pos_yxs, end_points_yxs
|
||||
|
||||
|
||||
def generate_pivot_list_horizontal(p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
score_thresh=0.5,
|
||||
is_backbone=False,
|
||||
image_id=0):
|
||||
"""
|
||||
return center point and end point of TCL instance; filter with the char maps;
|
||||
"""
|
||||
p_score = p_score[0]
|
||||
f_direction = f_direction.transpose(1, 2, 0)
|
||||
p_tcl_map_bi = (p_score > score_thresh) * 1.0
|
||||
instance_count, instance_label_map = cv2.connectedComponents(
|
||||
p_tcl_map_bi.astype(np.uint8), connectivity=8)
|
||||
|
||||
# get TCL Instance
|
||||
all_pos_yxs = []
|
||||
center_pos_yxs = []
|
||||
end_points_yxs = []
|
||||
instance_center_pos_yxs = []
|
||||
|
||||
if instance_count > 0:
|
||||
for instance_id in range(1, instance_count):
|
||||
pos_list = []
|
||||
ys, xs = np.where(instance_label_map == instance_id)
|
||||
pos_list = list(zip(ys, xs))
|
||||
|
||||
### FIX-ME, eliminate outlier
|
||||
if len(pos_list) < 5:
|
||||
continue
|
||||
|
||||
# add rule here
|
||||
main_direction = extract_main_direction(pos_list,
|
||||
f_direction) # y x
|
||||
reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x
|
||||
is_h_angle = abs(np.sum(
|
||||
main_direction * reference_directin)) < math.cos(math.pi / 180 *
|
||||
70)
|
||||
|
||||
point_yxs = np.array(pos_list)
|
||||
max_y, max_x = np.max(point_yxs, axis=0)
|
||||
min_y, min_x = np.min(point_yxs, axis=0)
|
||||
is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
|
||||
|
||||
pos_list_final = []
|
||||
if is_h_len:
|
||||
xs = np.unique(xs)
|
||||
for x in xs:
|
||||
ys = instance_label_map[:, x].copy().reshape((-1, ))
|
||||
y = int(np.where(ys == instance_id)[0].mean())
|
||||
pos_list_final.append((y, x))
|
||||
else:
|
||||
ys = np.unique(ys)
|
||||
for y in ys:
|
||||
xs = instance_label_map[y, :].copy().reshape((-1, ))
|
||||
x = int(np.where(xs == instance_id)[0].mean())
|
||||
pos_list_final.append((y, x))
|
||||
|
||||
pos_list_sorted, _ = sort_with_direction(pos_list_final,
|
||||
f_direction)
|
||||
all_pos_yxs.append(pos_list_sorted)
|
||||
|
||||
# use decoder to filter backgroud points.
|
||||
p_char_maps = p_char_maps.transpose([1, 2, 0])
|
||||
decode_res = ctc_decoder_for_image(
|
||||
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
|
||||
for decoded_str, keep_yxs_list in decode_res:
|
||||
if is_backbone:
|
||||
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
|
||||
instance_center_pos_yxs.append(keep_yxs_list_with_id)
|
||||
else:
|
||||
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
|
||||
center_pos_yxs.extend(keep_yxs_list)
|
||||
|
||||
if is_backbone:
|
||||
return instance_center_pos_yxs
|
||||
else:
|
||||
return center_pos_yxs, end_points_yxs
|
||||
|
||||
|
||||
def generate_pivot_list(p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
score_thresh=0.5,
|
||||
is_backbone=False,
|
||||
is_curved=True,
|
||||
image_id=0):
|
||||
"""
|
||||
Warp all the function together.
|
||||
"""
|
||||
if is_curved:
|
||||
return generate_pivot_list_curved(
|
||||
p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
score_thresh=score_thresh,
|
||||
is_expand=True,
|
||||
is_backbone=is_backbone,
|
||||
image_id=image_id)
|
||||
else:
|
||||
return generate_pivot_list_horizontal(
|
||||
p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
score_thresh=score_thresh,
|
||||
is_backbone=is_backbone,
|
||||
image_id=image_id)
|
||||
|
||||
|
||||
# for refine module
|
||||
def extract_main_direction(pos_list, f_direction):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
pos_list: [[y, x], [y, x], [y, x] ...]
|
||||
"""
|
||||
pos_list = np.array(pos_list)
|
||||
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
|
||||
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
||||
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
||||
average_direction = average_direction / (
|
||||
np.linalg.norm(average_direction) + 1e-6)
|
||||
return average_direction
|
||||
|
||||
|
||||
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
|
||||
"""
|
||||
pos_list_full = np.array(pos_list).reshape(-1, 3)
|
||||
pos_list = pos_list_full[:, 1:]
|
||||
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
|
||||
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
||||
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
||||
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
|
||||
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
|
||||
return sorted_list
|
||||
|
||||
|
||||
def sort_by_direction_with_image_id(pos_list, f_direction):
|
||||
"""
|
||||
f_direction: h x w x 2
|
||||
pos_list: [[y, x], [y, x], [y, x] ...]
|
||||
"""
|
||||
|
||||
def sort_part_with_direction(pos_list_full, point_direction):
|
||||
pos_list_full = np.array(pos_list_full).reshape(-1, 3)
|
||||
pos_list = pos_list_full[:, 1:]
|
||||
point_direction = np.array(point_direction).reshape(-1, 2)
|
||||
average_direction = np.mean(point_direction, axis=0, keepdims=True)
|
||||
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
|
||||
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
|
||||
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
|
||||
return sorted_list, sorted_direction
|
||||
|
||||
pos_list = np.array(pos_list).reshape(-1, 3)
|
||||
point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
|
||||
point_direction = point_direction[:, ::-1] # x, y -> y, x
|
||||
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
|
||||
point_direction)
|
||||
|
||||
point_num = len(sorted_point)
|
||||
if point_num >= 16:
|
||||
middle_num = point_num // 2
|
||||
first_part_point = sorted_point[:middle_num]
|
||||
first_point_direction = sorted_direction[:middle_num]
|
||||
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
|
||||
first_part_point, first_point_direction)
|
||||
|
||||
last_part_point = sorted_point[middle_num:]
|
||||
last_point_direction = sorted_direction[middle_num:]
|
||||
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
|
||||
last_part_point, last_point_direction)
|
||||
sorted_point = sorted_fist_part_point + sorted_last_part_point
|
||||
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
|
||||
|
||||
return sorted_point
|
||||
|
||||
|
||||
def generate_pivot_list_tt_inference(p_score,
|
||||
p_char_maps,
|
||||
f_direction,
|
||||
score_thresh=0.5,
|
||||
is_backbone=False,
|
||||
is_curved=True,
|
||||
image_id=0):
|
||||
"""
|
||||
return center point and end point of TCL instance; filter with the char maps;
|
||||
"""
|
||||
p_score = p_score[0]
|
||||
f_direction = f_direction.transpose(1, 2, 0)
|
||||
p_tcl_map = (p_score > score_thresh) * 1.0
|
||||
skeleton_map = thin(p_tcl_map)
|
||||
instance_count, instance_label_map = cv2.connectedComponents(
|
||||
skeleton_map.astype(np.uint8), connectivity=8)
|
||||
|
||||
# get TCL Instance
|
||||
all_pos_yxs = []
|
||||
if instance_count > 0:
|
||||
for instance_id in range(1, instance_count):
|
||||
pos_list = []
|
||||
ys, xs = np.where(instance_label_map == instance_id)
|
||||
pos_list = list(zip(ys, xs))
|
||||
### FIX-ME, eliminate outlier
|
||||
if len(pos_list) < 3:
|
||||
continue
|
||||
pos_list_sorted = sort_and_expand_with_direction_v2(
|
||||
pos_list, f_direction, p_tcl_map)
|
||||
pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
|
||||
all_pos_yxs.append(pos_list_sorted_with_id)
|
||||
return all_pos_yxs
|
|
@ -0,0 +1,162 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import numpy as np
|
||||
import cv2
|
||||
import time
|
||||
|
||||
|
||||
def resize_image(im, max_side_len=512):
|
||||
"""
|
||||
resize image to a size multiple of max_stride which is required by the network
|
||||
:param im: the resized image
|
||||
:param max_side_len: limit of max image size to avoid out of memory in gpu
|
||||
:return: the resized image and the resize ratio
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
if resize_h > resize_w:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
else:
|
||||
ratio = float(max_side_len) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
|
||||
def resize_image_min(im, max_side_len=512):
|
||||
"""
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
|
||||
if resize_h < resize_w:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
else:
|
||||
ratio = float(max_side_len) / resize_w
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
|
||||
def resize_image_for_totaltext(im, max_side_len=512):
|
||||
"""
|
||||
"""
|
||||
h, w, _ = im.shape
|
||||
|
||||
resize_w = w
|
||||
resize_h = h
|
||||
ratio = 1.25
|
||||
if h * ratio > max_side_len:
|
||||
ratio = float(max_side_len) / resize_h
|
||||
|
||||
resize_h = int(resize_h * ratio)
|
||||
resize_w = int(resize_w * ratio)
|
||||
|
||||
max_stride = 128
|
||||
resize_h = (resize_h + max_stride - 1) // max_stride * max_stride
|
||||
resize_w = (resize_w + max_stride - 1) // max_stride * max_stride
|
||||
im = cv2.resize(im, (int(resize_w), int(resize_h)))
|
||||
ratio_h = resize_h / float(h)
|
||||
ratio_w = resize_w / float(w)
|
||||
return im, (ratio_h, ratio_w)
|
||||
|
||||
|
||||
def point_pair2poly(point_pair_list):
|
||||
"""
|
||||
Transfer vertical point_pairs into poly point in clockwise.
|
||||
"""
|
||||
pair_length_list = []
|
||||
for point_pair in point_pair_list:
|
||||
pair_length = np.linalg.norm(point_pair[0] - point_pair[1])
|
||||
pair_length_list.append(pair_length)
|
||||
pair_length_list = np.array(pair_length_list)
|
||||
pair_info = (pair_length_list.max(), pair_length_list.min(),
|
||||
pair_length_list.mean())
|
||||
|
||||
point_num = len(point_pair_list) * 2
|
||||
point_list = [0] * point_num
|
||||
for idx, point_pair in enumerate(point_pair_list):
|
||||
point_list[idx] = point_pair[0]
|
||||
point_list[point_num - 1 - idx] = point_pair[1]
|
||||
return np.array(point_list).reshape(-1, 2), pair_info
|
||||
|
||||
|
||||
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
|
||||
"""
|
||||
Generate shrink_quad_along_width.
|
||||
"""
|
||||
ratio_pair = np.array(
|
||||
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
||||
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
||||
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
||||
|
||||
|
||||
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
|
||||
"""
|
||||
expand poly along width.
|
||||
"""
|
||||
point_num = poly.shape[0]
|
||||
left_quad = np.array(
|
||||
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
|
||||
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
|
||||
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
|
||||
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
|
||||
right_quad = np.array(
|
||||
[
|
||||
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
|
||||
poly[point_num // 2], poly[point_num // 2 + 1]
|
||||
],
|
||||
dtype=np.float32)
|
||||
right_ratio = 1.0 + \
|
||||
shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
|
||||
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
|
||||
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
|
||||
poly[0] = left_quad_expand[0]
|
||||
poly[-1] = left_quad_expand[-1]
|
||||
poly[point_num // 2 - 1] = right_quad_expand[1]
|
||||
poly[point_num // 2] = right_quad_expand[2]
|
||||
return poly
|
||||
|
||||
|
||||
def norm2(x, axis=None):
|
||||
if axis:
|
||||
return np.sqrt(np.sum(x**2, axis=axis))
|
||||
return np.sqrt(np.sum(x**2))
|
||||
|
||||
|
||||
def cos(p1, p2):
|
||||
return (p1 * p2).sum() / (norm2(p1) * norm2(p2))
|
|
@ -1,9 +1,11 @@
|
|||
shapely
|
||||
imgaug
|
||||
scikit-image==0.17.2
|
||||
imgaug==0.4.0
|
||||
pyclipper
|
||||
lmdb
|
||||
opencv-python==4.2.0.32
|
||||
tqdm
|
||||
numpy
|
||||
visualdl
|
||||
python-Levenshtein
|
||||
python-Levenshtein
|
||||
opencv-contrib-python
|
2
setup.py
|
@ -32,7 +32,7 @@ setup(
|
|||
package_dir={'paddleocr': ''},
|
||||
include_package_data=True,
|
||||
entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]},
|
||||
version='2.0.2',
|
||||
version='2.0.3',
|
||||
install_requires=requirements,
|
||||
license='Apache License 2.0',
|
||||
description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices',
|
||||
|
|
|
@ -47,6 +47,7 @@ def main():
|
|||
config['Architecture']["Head"]['out_channels'] = len(
|
||||
getattr(post_process_class, 'character'))
|
||||
model = build_model(config['Architecture'])
|
||||
use_srn = config['Architecture']['algorithm'] == "SRN"
|
||||
|
||||
best_model_dict = init_model(config, model, logger)
|
||||
if len(best_model_dict):
|
||||
|
@ -59,7 +60,7 @@ def main():
|
|||
|
||||
# start eval
|
||||
metirc = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class)
|
||||
eval_class, use_srn)
|
||||
logger.info('metric eval ***************')
|
||||
for k, v in metirc.items():
|
||||
logger.info('{}:{}'.format(k, v))
|
||||
|
|
|
@ -98,10 +98,10 @@ class TextClassifier(object):
|
|||
norm_img_batch = np.concatenate(norm_img_batch)
|
||||
norm_img_batch = norm_img_batch.copy()
|
||||
starttime = time.time()
|
||||
|
||||
self.input_tensor.copy_from_cpu(norm_img_batch)
|
||||
self.predictor.run()
|
||||
prob_out = self.output_tensors[0].copy_to_cpu()
|
||||
self.predictor.try_shrink_memory()
|
||||
cls_result = self.postprocess_op(prob_out)
|
||||
elapse += time.time() - starttime
|
||||
for rno in range(len(cls_result)):
|
||||
|
|