Merge branch 'update_requirements' of https://github.com/WenmuZhou/PaddleOCR into update_requirements
This commit is contained in:
commit
7b7c8f3bb7
@@ -147,6 +147,7 @@ class MainWindow(QMainWindow, WindowMixin):
 self.itemsToShapesbox = {}
 self.shapesToItemsbox = {}
 self.prevLabelText = getStr('tempLabel')
+self.noLabelText = getStr('nullLabel')
 self.model = 'paddle'
 self.PPreader = None
 self.autoSaveNum = 5
@@ -1020,7 +1021,7 @@ class MainWindow(QMainWindow, WindowMixin):
 item.setText(str([(int(p.x()), int(p.y())) for p in shape.points]))
 self.updateComboBox()

-def updateComboBox(self): # TODO:貌似没用
+def updateComboBox(self):
 # Get the unique labels and add them to the Combobox.
 itemsTextList = [str(self.labelList.item(i).text()) for i in range(self.labelList.count())]

@@ -1040,7 +1041,7 @@ class MainWindow(QMainWindow, WindowMixin):
 return dict(label=s.label, # str
 line_color=s.line_color.getRgb(),
 fill_color=s.fill_color.getRgb(),
-points=[(p.x(), p.y()) for p in s.points], # QPonitF
+points=[(int(p.x()), int(p.y())) for p in s.points], # QPonitF
 # add chris
 difficult=s.difficult) # bool

@@ -1069,7 +1070,7 @@ class MainWindow(QMainWindow, WindowMixin):
 # print('Image:{0} -> Annotation:{1}'.format(self.filePath, annotationFilePath))
 return True
 except:
-self.errorMessage(u'Error saving label data')
+self.errorMessage(u'Error saving label data', u'Error saving label data')
 return False

 def copySelectedShape(self):
@@ -1802,7 +1803,11 @@ class MainWindow(QMainWindow, WindowMixin):
 result.insert(0, box)
 print('result in reRec is ', result)
 self.result_dic.append(result)
-if result[1][0] == shape.label:
+else:
+print('Can not recognise the box')
+self.result_dic.append([box,(self.noLabelText,0)])
+
+if self.noLabelText == shape.label or result[1][0] == shape.label:
 print('label no change')
 else:
 rec_flag += 1
@@ -1836,9 +1841,14 @@ class MainWindow(QMainWindow, WindowMixin):
 print('label no change')
 else:
 shape.label = result[1][0]
+else:
+print('Can not recognise the box')
+if self.noLabelText == shape.label:
+print('label no change')
+else:
+shape.label = self.noLabelText
 self.singleLabel(shape)
 self.setDirty()
-print(box)

 def autolcm(self):
 vbox = QVBoxLayout()
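The two hunks above change PPOCRLabel's re-recognition flow so that boxes the recognizer cannot read are kept and labelled with the new nullLabel string instead of being silently skipped. A minimal standalone sketch of that fallback, with illustrative names and a hypothetical result structure (not PPOCRLabel's actual code):

```python
noLabelText = 'NULL'  # PPOCRLabel reads this via getStr('nullLabel'); the value depends on the locale

def rerec_one_box(box, rec_result, shape_label, result_dic):
    """rec_result is assumed to be a (text, score) tuple, or None when recognition fails."""
    if rec_result is not None:
        result_dic.append([box, rec_result])
        new_label = rec_result[0]
    else:
        # fallback added in this commit: record the box with the placeholder label
        print('Can not recognise the box')
        result_dic.append([box, (noLabelText, 0)])
        new_label = noLabelText
    if new_label == shape_label:
        print('label no change')
    return new_label

results = []
print(rerec_one_box([[0, 0], [10, 0], [10, 10], [0, 10]], None, 'NULL', results))
```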
@@ -45,7 +45,7 @@ class Canvas(QWidget):
 CREATE, EDIT = list(range(2))
 _fill_drawing = False # draw shadows

-epsilon = 11.0
+epsilon = 5.0

 def __init__(self, *args, **kwargs):
 super(Canvas, self).__init__(*args, **kwargs)
File diff suppressed because it is too large
@@ -87,6 +87,7 @@ creatPolygon=四点标注
 drawSquares=正方形标注
 saveRec=保存识别结果
 tempLabel=待识别
+nullLabel=无法识别
 steps=操作步骤
 choseModelLg=选择模型语言
 cancel=取消
@@ -77,7 +77,7 @@ IR=Image Resize
 autoRecognition=Auto Recognition
 reRecognition=Re-recognition
 mfile=File
-medit=Eidt
+medit=Edit
 mview=View
 mhelp=Help
 iconList=Icon List
@@ -87,6 +87,7 @@ creatPolygon=Create Quadrilateral
 drawSquares=Draw Squares
 saveRec=Save Recognition Result
 tempLabel=TEMPORARY
+nullLabel=NULL
 steps=Steps
 choseModelLg=Choose Model Language
 cancel=Cancel
@@ -32,7 +32,8 @@ PaddleOCR supports both dynamic graph and static graph programming paradigm

 <div align="center">
 <img src="doc/imgs_results/ch_ppocr_mobile_v2.0/test_add_91.jpg" width="800">
-<img src="doc/imgs_results/ch_ppocr_mobile_v2.0/00018069.jpg" width="800">
+<img src="doc/imgs_results/multi_lang/img_01.jpg" width="800">
+<img src="doc/imgs_results/multi_lang/img_02.jpg" width="800">
 </div>

 The above pictures are the visualizations of the general ppocr_server model. For more effect pictures, please see [More visualizations](./doc/doc_en/visualization_en.md).
@@ -62,20 +62,21 @@ PostProcess:
 mode: fast # fast or slow two ways
 Metric:
 name: E2EMetric
-gt_mat_dir: # the dir of gt_mat
+gt_mat_dir: ./train_data/total_text/gt # the dir of gt_mat
 character_dict_path: ppocr/utils/ic15_dict.txt
 main_indicator: f_score_e2e

 Train:
 dataset:
 name: PGDataSet
-label_file_list: [.././train_data/total_text/train/]
+data_dir: ./train_data/total_text/train
+label_file_list: [./train_data/total_text/train/]
 ratio_list: [1.0]
-data_format: icdar #two data format: icdar/textnet
 transforms:
 - DecodeImage: # load image
 img_mode: BGR
 channel_first: False
+- E2ELabelEncode:
 - PGProcessTrain:
 batch_size: 14 # same as loader: batch_size_per_card
 min_crop_size: 24
@@ -92,13 +93,12 @@ Train:
 Eval:
 dataset:
 name: PGDataSet
-data_dir: ./train_data/
+data_dir: ./train_data/total_text/test
 label_file_list: [./train_data/total_text/test/]
 transforms:
 - DecodeImage: # load image
 img_mode: RGB
 channel_first: False
-- E2ELabelEncode:
 - E2EResizeForTest:
 max_side_len: 768
 - NormalizeImage:
@@ -108,7 +108,7 @@ Eval:
 order: 'hwc'
 - ToCHWImage:
 - KeepKeys:
-keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags', 'img_id']
+keep_keys: [ 'image', 'shape', 'img_id']
 loader:
 shuffle: False
 drop_last: False
@@ -118,7 +118,6 @@ class ArgsParser(ArgumentParser):
 return config

 def _set_language(self, type):
-print("type:", type)
 lang = type[0]
 assert (type), "please use -l or --language to choose language type"
 assert(
@@ -113,7 +113,7 @@ python3 generate_multi_language_configs.py -l it \
 | cyrillic_mobile_v2.0_rec | 斯拉夫字母 | [rec_cyrillic_lite_train.yml](../../configs/rec/multi_language/rec_cyrillic_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_train.tar) |
 | devanagari_mobile_v2.0_rec | 梵文字母 | [rec_devanagari_lite_train.yml](../../configs/rec/multi_language/rec_devanagari_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_train.tar) |

-更多支持语种请参考: [多语言模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_ch/multi_languages.md#%E8%AF%AD%E7%A7%8D%E7%BC%A9%E5%86%99)
+更多支持语种请参考: [多语言模型](./multi_languages.md)


 <a name="文本方向分类模型"></a>
@@ -134,7 +134,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --rec false
 <a name="python_脚本运行"></a>
 ### 2.2 python 脚本运行

-ppocr 也支持在python脚本中运行,便于嵌入到您自己的代码中:
+ppocr 也支持在python脚本中运行,便于嵌入到您自己的代码中 :

 * 整图预测(检测+识别)

@@ -155,7 +155,7 @@ image = Image.open(img_path).convert('RGB')
 boxes = [line[0] for line in result]
 txts = [line[1][0] for line in result]
 scores = [line[1][1] for line in result]
-im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/korean.ttf')
+im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/korean.ttf')
 im_show = Image.fromarray(im_show)
 im_show.save('result.jpg')
 ```
@@ -240,7 +240,7 @@ ppocr 支持使用自己的数据进行自定义训练或finetune, 其中识别
 |德文|german|german|
 |日文|japan|japan|
 |韩文|korean|korean|
-|中文繁体|chinese traditional |ch_tra|
+|中文繁体|chinese traditional |chinese_cht|
 |意大利文| Italian |it|
 |西班牙文|Spanish |es|
 |葡萄牙文| Portuguese|pt|
|
@ -259,7 +259,6 @@ ppocr 支持使用自己的数据进行自定义训练或finetune, 其中识别
|
||||||
|乌克兰文|Ukranian|uk|
|
|乌克兰文|Ukranian|uk|
|
||||||
|白俄罗斯文|Belarusian|be|
|
|白俄罗斯文|Belarusian|be|
|
||||||
|泰卢固文|Telugu |te|
|
|泰卢固文|Telugu |te|
|
||||||
|卡纳达文|Kannada |kn|
|
|
||||||
|泰米尔文|Tamil |ta|
|
|泰米尔文|Tamil |ta|
|
||||||
|南非荷兰文 |Afrikaans |af|
|
|南非荷兰文 |Afrikaans |af|
|
||||||
|阿塞拜疆文 |Azerbaijani |az|
|
|阿塞拜疆文 |Azerbaijani |az|
|
||||||
|
|
|
@@ -111,7 +111,7 @@ python3 generate_multi_language_configs.py -l it \
 | cyrillic_mobile_v2.0_rec | Lightweight model for cyrillic recognition | [rec_cyrillic_lite_train.yml](../../configs/rec/multi_language/rec_cyrillic_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_train.tar) |
 | devanagari_mobile_v2.0_rec | Lightweight model for devanagari recognition | [rec_devanagari_lite_train.yml](../../configs/rec/multi_language/rec_devanagari_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_train.tar) |

-For more supported languages, please refer to : [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations)
+For more supported languages, please refer to : [Multi-language model](./multi_languages_en.md)


 <a name="Angle"></a>
@@ -153,7 +153,7 @@ image = Image.open(img_path).convert('RGB')
 boxes = [line[0] for line in result]
 txts = [line[1][0] for line in result]
 scores = [line[1][1] for line in result]
-im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/korean.ttf')
+im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/korean.ttf')
 im_show = Image.fromarray(im_show)
 im_show.save('result.jpg')
 ```
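For context, the hunks above edit the package-usage docs to point at the relocated font under doc/fonts/. A self-contained sketch of the same flow is shown below; the image path and language choice are assumptions for illustration, not part of the diff:

```python
from paddleocr import PaddleOCR, draw_ocr  # draw_ocr is re-exported via the import added to paddleocr.py
from PIL import Image

img_path = 'doc/imgs_en/img_12.jpg'      # hypothetical sample image
ocr = PaddleOCR(lang='korean')           # downloads det/rec/cls models on first use
result = ocr.ocr(img_path)

image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result]
txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result]
# The Korean font now lives under doc/fonts/ per this commit.
im_show = draw_ocr(image, boxes, txts, scores, font_path='doc/fonts/korean.ttf')
Image.fromarray(im_show).save('result.jpg')
```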
@@ -232,7 +232,7 @@ For functions such as data annotation, you can read the complete [Document Tutor
 |german|german|
 |japan|japan|
 |korean|korean|
-|chinese traditional |ch_tra|
+|chinese traditional |chinese_cht|
 | Italian |it|
 |Spanish |es|
 | Portuguese|pt|
@@ -251,7 +251,6 @@ For functions such as data annotation, you can read the complete [Document Tutor
 |Ukranian|uk|
 |Belarusian|be|
 |Telugu |te|
-|Kannada |kn|
 |Tamil |ta|
 |Afrikaans |af|
 |Azerbaijani |az|
Binary file not shown.
After Width: | Height: | Size: 107 KiB |
Binary file not shown.
After Width: | Height: | Size: 231 KiB |
14  paddleocr.py
@@ -30,6 +30,7 @@ from ppocr.utils.logging import get_logger

 logger = get_logger()
 from ppocr.utils.utility import check_and_read_gif, get_image_file_list
+from tools.infer.utility import draw_ocr

 __all__ = ['PaddleOCR']

@@ -117,7 +118,7 @@ model_urls = {
 }

 SUPPORT_DET_MODEL = ['DB']
-VERSION = 2.1
+VERSION = '2.1'
 SUPPORT_REC_MODEL = ['CRNN']
 BASE_DIR = os.path.expanduser("~/.paddleocr/")

@@ -315,14 +316,13 @@ class PaddleOCR(predict_system.TextSystem):

 # init model dir
 if postprocess_params.det_model_dir is None:
-postprocess_params.det_model_dir = os.path.join(
-BASE_DIR, '{}/det/{}'.format(VERSION, det_lang))
+postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION,
+'det', det_lang)
 if postprocess_params.rec_model_dir is None:
-postprocess_params.rec_model_dir = os.path.join(
-BASE_DIR, '{}/rec/{}'.format(VERSION, lang))
+postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
+'rec', lang)
 if postprocess_params.cls_model_dir is None:
-postprocess_params.cls_model_dir = os.path.join(
-BASE_DIR, '{}/cls'.format(VERSION))
+postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
 print(postprocess_params)
 # download model
 maybe_download(postprocess_params.det_model_dir,
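The paddleocr.py hunk above switches VERSION to a string and composes the default model directories from it, so the classifier directory is no longer version-keyed. Roughly, the resulting layout is sketched below; the 'ch' language values are illustrative only:

```python
import os

BASE_DIR = os.path.expanduser("~/.paddleocr/")
VERSION = '2.1'
lang, det_lang = 'ch', 'ch'  # illustrative values

det_model_dir = os.path.join(BASE_DIR, VERSION, 'det', det_lang)  # ~/.paddleocr/2.1/det/ch
rec_model_dir = os.path.join(BASE_DIR, VERSION, 'rec', lang)      # ~/.paddleocr/2.1/rec/ch
cls_model_dir = os.path.join(BASE_DIR, 'cls')                     # ~/.paddleocr/cls (no version component)
print(det_model_dir, rec_model_dir, cls_model_dir)
```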
@@ -96,7 +96,7 @@ class BaseRecLabelEncode(object):
 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
 'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
 'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
-'mr', 'ne'
+'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari'
 ]
 assert character_type in support_character_type, "Only {} are supported now but get {}".format(
 support_character_type, character_type)
@@ -187,29 +187,31 @@ class CTCLabelEncode(BaseRecLabelEncode):
 return dict_character


-class E2ELabelEncode(BaseRecLabelEncode):
-def __init__(self,
-max_text_length,
-character_dict_path=None,
-character_type='EN',
-use_space_char=False,
-**kwargs):
-super(E2ELabelEncode,
-self).__init__(max_text_length, character_dict_path,
-character_type, use_space_char)
-self.pad_num = len(self.dict) # the length to pad
+class E2ELabelEncode(object):
+def __init__(self, **kwargs):
+pass

 def __call__(self, data):
-texts = data['strs']
-temp_texts = []
-for text in texts:
-text = text.lower()
-text = self.encode(text)
-if text is None:
-return None
-text = text + [self.pad_num] * (self.max_text_len - len(text))
-temp_texts.append(text)
-data['strs'] = np.array(temp_texts)
+import json
+label = data['label']
+label = json.loads(label)
+nBox = len(label)
+boxes, txts, txt_tags = [], [], []
+for bno in range(0, nBox):
+box = label[bno]['points']
+txt = label[bno]['transcription']
+boxes.append(box)
+txts.append(txt)
+if txt in ['*', '###']:
+txt_tags.append(True)
+else:
+txt_tags.append(False)
+boxes = np.array(boxes, dtype=np.float32)
+txt_tags = np.array(txt_tags, dtype=np.bool)
+
+data['polys'] = boxes
+data['texts'] = txts
+data['ignore_tags'] = txt_tags
 return data


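The rewritten E2ELabelEncode above no longer inherits from BaseRecLabelEncode; it parses a JSON annotation string into polygons, texts, and ignore tags. A standalone sketch of that parsing, with made-up sample values (only the field names 'points' and 'transcription' come from the diff):

```python
import json
import numpy as np

# Hypothetical label string in the format the new E2ELabelEncode consumes:
# a JSON list of boxes, each with 'points' and 'transcription' ('###' or '*' marks ignored text).
label = json.dumps([
    {"points": [[10, 20], [110, 20], [110, 50], [10, 50]], "transcription": "hello"},
    {"points": [[15, 60], [120, 60], [120, 95], [15, 95]], "transcription": "###"},
])

parsed = json.loads(label)
polys = np.array([b["points"] for b in parsed], dtype=np.float32)
texts = [b["transcription"] for b in parsed]
ignore_tags = np.array([b["transcription"] in ["*", "###"] for b in parsed], dtype=bool)
print(polys.shape, texts, ignore_tags)  # (2, 4, 2) ['hello', '###'] [False  True]
```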
@@ -88,7 +88,7 @@ class PGProcessTrain(object):

 return min_area_quad

-def check_and_validate_polys(self, polys, tags, xxx_todo_changeme):
+def check_and_validate_polys(self, polys, tags, im_size):
 """
 check so that the text poly is in the same direction,
 and also filter some invalid polygons
@@ -96,7 +96,7 @@ class PGProcessTrain(object):
 :param tags:
 :return:
 """
-(h, w) = xxx_todo_changeme
+(h, w) = im_size
 if polys.shape[0] == 0:
 return polys, np.array([]), np.array([])
 polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
@@ -750,8 +750,8 @@ class PGProcessTrain(object):
 input_size = 512
 im = data['image']
 text_polys = data['polys']
-text_tags = data['tags']
-text_strs = data['strs']
+text_tags = data['ignore_tags']
+text_strs = data['texts']
 h, w, _ = im.shape
 text_polys, text_tags, hv_tags = self.check_and_validate_polys(
 text_polys, text_tags, (h, w))
@@ -29,20 +29,20 @@ class PGDataSet(Dataset):
 dataset_config = config[mode]['dataset']
 loader_config = config[mode]['loader']

+self.delimiter = dataset_config.get('delimiter', '\t')
 label_file_list = dataset_config.pop('label_file_list')
 data_source_num = len(label_file_list)
 ratio_list = dataset_config.get("ratio_list", [1.0])
 if isinstance(ratio_list, (float, int)):
 ratio_list = [float(ratio_list)] * int(data_source_num)
-self.data_format = dataset_config.get('data_format', 'icdar')
 assert len(
 ratio_list
 ) == data_source_num, "The length of ratio_list should be the same as the file_list."
+self.data_dir = dataset_config['data_dir']
 self.do_shuffle = loader_config['shuffle']

 logger.info("Initialize indexs of datasets:%s" % label_file_list)
-self.data_lines = self.get_image_info_list(label_file_list, ratio_list,
-self.data_format)
+self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
 self.data_idx_order_list = list(range(len(self.data_lines)))
 if mode.lower() == "train":
 self.shuffle_data_random()
@@ -55,108 +55,40 @@ class PGDataSet(Dataset):
 random.shuffle(self.data_lines)
 return

-def extract_polys(self, poly_txt_path):
-"""
-Read text_polys, txt_tags, txts from give txt file.
-"""
-text_polys, txt_tags, txts = [], [], []
-with open(poly_txt_path) as f:
-for line in f.readlines():
-poly_str, txt = line.strip().split('\t')
-poly = list(map(float, poly_str.split(',')))
-text_polys.append(
-np.array(
-poly, dtype=np.float32).reshape(-1, 2))
-txts.append(txt)
-txt_tags.append(txt == '###')
-
-return np.array(list(map(np.array, text_polys))), \
-np.array(txt_tags, dtype=np.bool), txts
-
-def extract_info_textnet(self, im_fn, img_dir=''):
-"""
-Extract information from line in textnet format.
-"""
-info_list = im_fn.split('\t')
-img_path = ''
-for ext in [
-'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'JPG'
-]:
-if os.path.exists(os.path.join(img_dir, info_list[0] + "." + ext)):
-img_path = os.path.join(img_dir, info_list[0] + "." + ext)
-break
-
-if img_path == '':
-print('Image {0} NOT found in {1}, and it will be ignored.'.format(
-info_list[0], img_dir))
-
-nBox = (len(info_list) - 1) // 9
-wordBBs, txts, txt_tags = [], [], []
-for n in range(0, nBox):
-wordBB = list(map(float, info_list[n * 9 + 1:(n + 1) * 9]))
-txt = info_list[(n + 1) * 9]
-wordBBs.append([[wordBB[0], wordBB[1]], [wordBB[2], wordBB[3]],
-[wordBB[4], wordBB[5]], [wordBB[6], wordBB[7]]])
-txts.append(txt)
-if txt == '###':
-txt_tags.append(True)
-else:
-txt_tags.append(False)
-return img_path, np.array(wordBBs, dtype=np.float32), txt_tags, txts
-
-def get_image_info_list(self, file_list, ratio_list, data_format='textnet'):
+def get_image_info_list(self, file_list, ratio_list):
 if isinstance(file_list, str):
 file_list = [file_list]
 data_lines = []
-for idx, data_source in enumerate(file_list):
-image_files = []
-if data_format == 'icdar':
-image_files = [(data_source, x) for x in
-os.listdir(os.path.join(data_source, 'rgb'))
-if x.split('.')[-1] in [
-'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif',
-'tiff', 'gif', 'JPG'
-]]
-elif data_format == 'textnet':
-with open(data_source) as f:
-image_files = [(data_source, x.strip())
-for x in f.readlines()]
-else:
-print("Unrecognized data format...")
-exit(-1)
+for idx, file in enumerate(file_list):
+with open(file, "rb") as f:
+lines = f.readlines()
+if self.mode == "train" or ratio_list[idx] < 1.0:
 random.seed(self.seed)
-image_files = random.sample(
-image_files, round(len(image_files) * ratio_list[idx]))
-data_lines.extend(image_files)
+lines = random.sample(lines,
+round(len(lines) * ratio_list[idx]))
+data_lines.extend(lines)
 return data_lines

 def __getitem__(self, idx):
 file_idx = self.data_idx_order_list[idx]
-data_path, data_line = self.data_lines[file_idx]
+data_line = self.data_lines[file_idx]
 try:
-if self.data_format == 'icdar':
-im_path = os.path.join(data_path, 'rgb', data_line)
-poly_path = os.path.join(data_path, 'poly',
-data_line.split('.')[0] + '.txt')
-text_polys, text_tags, text_strs = self.extract_polys(poly_path)
+data_line = data_line.decode('utf-8')
+substr = data_line.strip("\n").split(self.delimiter)
+file_name = substr[0]
+label = substr[1]
+img_path = os.path.join(self.data_dir, file_name)
+if self.mode.lower() == 'eval':
+img_id = int(data_line.split(".")[0][7:])
 else:
-image_dir = os.path.join(os.path.dirname(data_path), 'image')
-im_path, text_polys, text_tags, text_strs = self.extract_info_textnet(
-data_line, image_dir)
-img_id = int(data_line.split(".")[0][3:])
+img_id = 0
+data = {'img_path': img_path, 'label': label, 'img_id': img_id}
+if not os.path.exists(img_path):
+raise Exception("{} does not exist!".format(img_path))

-data = {
-'img_path': im_path,
-'polys': text_polys,
-'tags': text_tags,
-'strs': text_strs,
-'img_id': img_id
-}
 with open(data['img_path'], 'rb') as f:
 img = f.read()
 data['image'] = img
 outs = transform(data, self.ops)

 except Exception as e:
 self.logger.error(
 "When parsing line {}, error happened with msg: {}".format(
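With the PGDataSet rewrite above, annotations are no longer read from per-image poly files or textnet lines; each label-file row is an image path plus a JSON label, split on the configured delimiter. A sketch of one such row and how it is taken apart (file name, directory, and text are invented for illustration):

```python
import json
import os

# Hypothetical one-line label file in the format the reworked PGDataSet expects:
# <relative image path> \t <JSON list of {"transcription": ..., "points": ...}>
data_dir = "./train_data/total_text/train"  # assumed layout, following the config diff
line = "rgb/img1001.jpg\t" + json.dumps([
    {"transcription": "PADDLE", "points": [[5, 5], [90, 5], [90, 30], [5, 30]]}
])

file_name, label = line.strip("\n").split("\t")
img_path = os.path.join(data_dir, file_name)
boxes = json.loads(label)
print(img_path, boxes[0]["transcription"])
```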
@@ -35,11 +35,11 @@ class E2EMetric(object):
 self.reset()

 def __call__(self, preds, batch, **kwargs):
-img_id = batch[5][0]
+img_id = batch[2][0]
 e2e_info_list = [{
 'points': det_polyon,
-'text': pred_str
-} for det_polyon, pred_str in zip(preds['points'], preds['strs'])]
+'texts': pred_str
+} for det_polyon, pred_str in zip(preds['points'], preds['texts'])]
 result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
 self.results.append(result)

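The switch from batch[5][0] to batch[2][0] matches the trimmed keep_keys list in the Eval config earlier in this commit: the batch handed to the metric follows keep_keys order, so img_id moves from position 5 to position 2. A toy illustration of that indexing:

```python
# keep_keys before and after this commit, as given in the e2e config diff.
old_keep_keys = ['image', 'shape', 'polys', 'strs', 'tags', 'img_id']
new_keep_keys = ['image', 'shape', 'img_id']

print(old_keep_keys.index('img_id'))  # 5 -> batch[5][0] in the old E2EMetric
print(new_keep_keys.index('img_id'))  # 2 -> batch[2][0] after this change
```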
@@ -28,7 +28,7 @@ class BaseRecLabelDecode(object):
 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
-'ne', 'EN'
+'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari'
 ]
 assert character_type in support_character_type, "Only {} are supported now but get {}".format(
 support_character_type, character_type)
@@ -26,7 +26,7 @@ def get_socre(gt_dir, img_id, pred_dict):
 n = len(pred_dict)
 for i in range(n):
 points = pred_dict[i]['points']
-text = pred_dict[i]['text']
+text = pred_dict[i]['texts']
 point = ",".join(map(str, points.reshape(-1, )))
 det.append([point, text])
 return det
@@ -21,6 +21,7 @@ import math

 import numpy as np
 from itertools import groupby
+from cv2.ximgproc import thinning as thin
 from skimage.morphology._skeletonize import thin


@@ -64,7 +64,7 @@ class PGNet_PostProcess(object):
 src_w, src_h, self.valid_set)
 data = {
 'points': poly_list,
-'strs': keep_str_list,
+'texts': keep_str_list,
 }
 return data

@@ -176,6 +176,6 @@ class PGNet_PostProcess(object):
 exit(-1)
 data = {
 'points': poly_list,
-'strs': keep_str_list,
+'texts': keep_str_list,
 }
 return data
@@ -122,7 +122,7 @@ class TextE2E(object):
 else:
 raise NotImplementedError
 post_result = self.postprocess_op(preds, shape_list)
-points, strs = post_result['points'], post_result['strs']
+points, strs = post_result['points'], post_result['texts']
 dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
 elapse = time.time() - starttime
 return dt_boxes, strs, elapse
@@ -103,7 +103,7 @@ def main():
 images = paddle.to_tensor(images)
 preds = model(images)
 post_result = post_process_class(preds, shape_list)
-points, strs = post_result['points'], post_result['strs']
+points, strs = post_result['points'], post_result['texts']
 # write resule
 dt_boxes_json = []
 for poly, str in zip(points, strs):