Merge branch 'update_requirements' of https://github.com/WenmuZhou/PaddleOCR into update_requirements

This commit is contained in:
WenmuZhou 2021-04-17 19:29:49 +08:00
commit 7b7c8f3bb7
25 changed files with 4981 additions and 5034 deletions

View File

@ -147,6 +147,7 @@ class MainWindow(QMainWindow, WindowMixin):
self.itemsToShapesbox = {} self.itemsToShapesbox = {}
self.shapesToItemsbox = {} self.shapesToItemsbox = {}
self.prevLabelText = getStr('tempLabel') self.prevLabelText = getStr('tempLabel')
self.noLabelText = getStr('nullLabel')
self.model = 'paddle' self.model = 'paddle'
self.PPreader = None self.PPreader = None
self.autoSaveNum = 5 self.autoSaveNum = 5
@ -1020,7 +1021,7 @@ class MainWindow(QMainWindow, WindowMixin):
item.setText(str([(int(p.x()), int(p.y())) for p in shape.points])) item.setText(str([(int(p.x()), int(p.y())) for p in shape.points]))
self.updateComboBox() self.updateComboBox()
def updateComboBox(self): # TODO貌似没用 def updateComboBox(self):
# Get the unique labels and add them to the Combobox. # Get the unique labels and add them to the Combobox.
itemsTextList = [str(self.labelList.item(i).text()) for i in range(self.labelList.count())] itemsTextList = [str(self.labelList.item(i).text()) for i in range(self.labelList.count())]
@ -1040,7 +1041,7 @@ class MainWindow(QMainWindow, WindowMixin):
return dict(label=s.label, # str return dict(label=s.label, # str
line_color=s.line_color.getRgb(), line_color=s.line_color.getRgb(),
fill_color=s.fill_color.getRgb(), fill_color=s.fill_color.getRgb(),
points=[(p.x(), p.y()) for p in s.points], # QPonitF points=[(int(p.x()), int(p.y())) for p in s.points], # QPonitF
# add chris # add chris
difficult=s.difficult) # bool difficult=s.difficult) # bool
@ -1069,7 +1070,7 @@ class MainWindow(QMainWindow, WindowMixin):
# print('Image:{0} -> Annotation:{1}'.format(self.filePath, annotationFilePath)) # print('Image:{0} -> Annotation:{1}'.format(self.filePath, annotationFilePath))
return True return True
except: except:
self.errorMessage(u'Error saving label data') self.errorMessage(u'Error saving label data', u'Error saving label data')
return False return False
def copySelectedShape(self): def copySelectedShape(self):
@ -1802,7 +1803,11 @@ class MainWindow(QMainWindow, WindowMixin):
result.insert(0, box) result.insert(0, box)
print('result in reRec is ', result) print('result in reRec is ', result)
self.result_dic.append(result) self.result_dic.append(result)
if result[1][0] == shape.label: else:
print('Can not recognise the box')
self.result_dic.append([box,(self.noLabelText,0)])
if self.noLabelText == shape.label or result[1][0] == shape.label:
print('label no change') print('label no change')
else: else:
rec_flag += 1 rec_flag += 1
@ -1836,9 +1841,14 @@ class MainWindow(QMainWindow, WindowMixin):
print('label no change') print('label no change')
else: else:
shape.label = result[1][0] shape.label = result[1][0]
else:
print('Can not recognise the box')
if self.noLabelText == shape.label:
print('label no change')
else:
shape.label = self.noLabelText
self.singleLabel(shape) self.singleLabel(shape)
self.setDirty() self.setDirty()
print(box)
def autolcm(self): def autolcm(self):
vbox = QVBoxLayout() vbox = QVBoxLayout()

View File

@ -45,7 +45,7 @@ class Canvas(QWidget):
CREATE, EDIT = list(range(2)) CREATE, EDIT = list(range(2))
_fill_drawing = False # draw shadows _fill_drawing = False # draw shadows
epsilon = 11.0 epsilon = 5.0
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(Canvas, self).__init__(*args, **kwargs) super(Canvas, self).__init__(*args, **kwargs)

File diff suppressed because it is too large Load Diff

View File

@ -87,6 +87,7 @@ creatPolygon=四点标注
drawSquares=正方形标注 drawSquares=正方形标注
saveRec=保存识别结果 saveRec=保存识别结果
tempLabel=待识别 tempLabel=待识别
nullLabel=无法识别
steps=操作步骤 steps=操作步骤
choseModelLg=选择模型语言 choseModelLg=选择模型语言
cancel=取消 cancel=取消

View File

@ -77,7 +77,7 @@ IR=Image Resize
autoRecognition=Auto Recognition autoRecognition=Auto Recognition
reRecognition=Re-recognition reRecognition=Re-recognition
mfile=File mfile=File
medit=Eidt medit=Edit
mview=View mview=View
mhelp=Help mhelp=Help
iconList=Icon List iconList=Icon List
@ -87,6 +87,7 @@ creatPolygon=Create Quadrilateral
drawSquares=Draw Squares drawSquares=Draw Squares
saveRec=Save Recognition Result saveRec=Save Recognition Result
tempLabel=TEMPORARY tempLabel=TEMPORARY
nullLabel=NULL
steps=Steps steps=Steps
choseModelLg=Choose Model Language choseModelLg=Choose Model Language
cancel=Cancel cancel=Cancel

View File

@ -32,7 +32,8 @@ PaddleOCR supports both dynamic graph and static graph programming paradigm
<div align="center"> <div align="center">
<img src="doc/imgs_results/ch_ppocr_mobile_v2.0/test_add_91.jpg" width="800"> <img src="doc/imgs_results/ch_ppocr_mobile_v2.0/test_add_91.jpg" width="800">
<img src="doc/imgs_results/ch_ppocr_mobile_v2.0/00018069.jpg" width="800"> <img src="doc/imgs_results/multi_lang/img_01.jpg" width="800">
<img src="doc/imgs_results/multi_lang/img_02.jpg" width="800">
</div> </div>
The above pictures are the visualizations of the general ppocr_server model. For more effect pictures, please see [More visualizations](./doc/doc_en/visualization_en.md). The above pictures are the visualizations of the general ppocr_server model. For more effect pictures, please see [More visualizations](./doc/doc_en/visualization_en.md).

View File

@ -62,20 +62,21 @@ PostProcess:
mode: fast # fast or slow two ways mode: fast # fast or slow two ways
Metric: Metric:
name: E2EMetric name: E2EMetric
gt_mat_dir: # the dir of gt_mat gt_mat_dir: ./train_data/total_text/gt # the dir of gt_mat
character_dict_path: ppocr/utils/ic15_dict.txt character_dict_path: ppocr/utils/ic15_dict.txt
main_indicator: f_score_e2e main_indicator: f_score_e2e
Train: Train:
dataset: dataset:
name: PGDataSet name: PGDataSet
label_file_list: [.././train_data/total_text/train/] data_dir: ./train_data/total_text/train
label_file_list: [./train_data/total_text/train/]
ratio_list: [1.0] ratio_list: [1.0]
data_format: icdar #two data format: icdar/textnet
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- E2ELabelEncode:
- PGProcessTrain: - PGProcessTrain:
batch_size: 14 # same as loader: batch_size_per_card batch_size: 14 # same as loader: batch_size_per_card
min_crop_size: 24 min_crop_size: 24
@ -92,13 +93,12 @@ Train:
Eval: Eval:
dataset: dataset:
name: PGDataSet name: PGDataSet
data_dir: ./train_data/ data_dir: ./train_data/total_text/test
label_file_list: [./train_data/total_text/test/] label_file_list: [./train_data/total_text/test/]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: RGB img_mode: RGB
channel_first: False channel_first: False
- E2ELabelEncode:
- E2EResizeForTest: - E2EResizeForTest:
max_side_len: 768 max_side_len: 768
- NormalizeImage: - NormalizeImage:
@ -108,7 +108,7 @@ Eval:
order: 'hwc' order: 'hwc'
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags', 'img_id'] keep_keys: [ 'image', 'shape', 'img_id']
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False

View File

@ -118,7 +118,6 @@ class ArgsParser(ArgumentParser):
return config return config
def _set_language(self, type): def _set_language(self, type):
print("type:", type)
lang = type[0] lang = type[0]
assert (type), "please use -l or --language to choose language type" assert (type), "please use -l or --language to choose language type"
assert( assert(

View File

@ -113,7 +113,7 @@ python3 generate_multi_language_configs.py -l it \
| cyrillic_mobile_v2.0_rec | 斯拉夫字母 | [rec_cyrillic_lite_train.yml](../../configs/rec/multi_language/rec_cyrillic_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_train.tar) | | cyrillic_mobile_v2.0_rec | 斯拉夫字母 | [rec_cyrillic_lite_train.yml](../../configs/rec/multi_language/rec_cyrillic_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_train.tar) |
| devanagari_mobile_v2.0_rec | 梵文字母 | [rec_devanagari_lite_train.yml](../../configs/rec/multi_language/rec_devanagari_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_train.tar) | | devanagari_mobile_v2.0_rec | 梵文字母 | [rec_devanagari_lite_train.yml](../../configs/rec/multi_language/rec_devanagari_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_train.tar) |
更多支持语种请参考: [多语言模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_ch/multi_languages.md#%E8%AF%AD%E7%A7%8D%E7%BC%A9%E5%86%99) 更多支持语种请参考: [多语言模型](./multi_languages.md)
<a name="文本方向分类模型"></a> <a name="文本方向分类模型"></a>

View File

@ -134,7 +134,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --rec false
<a name="python_脚本运行"></a> <a name="python_脚本运行"></a>
### 2.2 python 脚本运行 ### 2.2 python 脚本运行
ppocr 也支持在python脚本中运行便于嵌入到您自己的代码中 ppocr 也支持在python脚本中运行便于嵌入到您自己的代码中
* 整图预测(检测+识别) * 整图预测(检测+识别)
@ -155,7 +155,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result] boxes = [line[0] for line in result]
txts = [line[1][0] for line in result] txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result] scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/korean.ttf') im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/korean.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
@ -240,7 +240,7 @@ ppocr 支持使用自己的数据进行自定义训练或finetune, 其中识别
|德文|german|german| |德文|german|german|
|日文|japan|japan| |日文|japan|japan|
|韩文|korean|korean| |韩文|korean|korean|
|中文繁体|chinese traditional |ch_tra| |中文繁体|chinese traditional |chinese_cht|
|意大利文| Italian |it| |意大利文| Italian |it|
|西班牙文|Spanish |es| |西班牙文|Spanish |es|
|葡萄牙文| Portuguese|pt| |葡萄牙文| Portuguese|pt|
@ -259,7 +259,6 @@ ppocr 支持使用自己的数据进行自定义训练或finetune, 其中识别
|乌克兰文|Ukranian|uk| |乌克兰文|Ukranian|uk|
|白俄罗斯文|Belarusian|be| |白俄罗斯文|Belarusian|be|
|泰卢固文|Telugu |te| |泰卢固文|Telugu |te|
|卡纳达文|Kannada |kn|
|泰米尔文|Tamil |ta| |泰米尔文|Tamil |ta|
|南非荷兰文 |Afrikaans |af| |南非荷兰文 |Afrikaans |af|
|阿塞拜疆文 |Azerbaijani |az| |阿塞拜疆文 |Azerbaijani |az|

View File

@ -111,7 +111,7 @@ python3 generate_multi_language_configs.py -l it \
| cyrillic_mobile_v2.0_rec | Lightweight model for cyrillic recognition | [rec_cyrillic_lite_train.yml](../../configs/rec/multi_language/rec_cyrillic_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_train.tar) | | cyrillic_mobile_v2.0_rec | Lightweight model for cyrillic recognition | [rec_cyrillic_lite_train.yml](../../configs/rec/multi_language/rec_cyrillic_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_train.tar) |
| devanagari_mobile_v2.0_rec | Lightweight model for devanagari recognition | [rec_devanagari_lite_train.yml](../../configs/rec/multi_language/rec_devanagari_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_train.tar) | | devanagari_mobile_v2.0_rec | Lightweight model for devanagari recognition | [rec_devanagari_lite_train.yml](../../configs/rec/multi_language/rec_devanagari_lite_train.yml) |2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_train.tar) |
For more supported languages, please refer to : [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations) For more supported languages, please refer to : [Multi-language model](./multi_languages_en.md)
<a name="Angle"></a> <a name="Angle"></a>

View File

@ -153,7 +153,7 @@ image = Image.open(img_path).convert('RGB')
boxes = [line[0] for line in result] boxes = [line[0] for line in result]
txts = [line[1][0] for line in result] txts = [line[1][0] for line in result]
scores = [line[1][1] for line in result] scores = [line[1][1] for line in result]
im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/korean.ttf') im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/fonts/korean.ttf')
im_show = Image.fromarray(im_show) im_show = Image.fromarray(im_show)
im_show.save('result.jpg') im_show.save('result.jpg')
``` ```
@ -232,7 +232,7 @@ For functions such as data annotation, you can read the complete [Document Tutor
|german|german| |german|german|
|japan|japan| |japan|japan|
|korean|korean| |korean|korean|
|chinese traditional |ch_tra| |chinese traditional |chinese_cht|
| Italian |it| | Italian |it|
|Spanish |es| |Spanish |es|
| Portuguese|pt| | Portuguese|pt|
@ -251,7 +251,6 @@ For functions such as data annotation, you can read the complete [Document Tutor
|Ukranian|uk| |Ukranian|uk|
|Belarusian|be| |Belarusian|be|
|Telugu |te| |Telugu |te|
|Kannada |kn|
|Tamil |ta| |Tamil |ta|
|Afrikaans |af| |Afrikaans |af|
|Azerbaijani |az| |Azerbaijani |az|

Binary file not shown.

After

Width:  |  Height:  |  Size: 107 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 231 KiB

View File

@ -30,6 +30,7 @@ from ppocr.utils.logging import get_logger
logger = get_logger() logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from tools.infer.utility import draw_ocr
__all__ = ['PaddleOCR'] __all__ = ['PaddleOCR']
@ -117,7 +118,7 @@ model_urls = {
} }
SUPPORT_DET_MODEL = ['DB'] SUPPORT_DET_MODEL = ['DB']
VERSION = 2.1 VERSION = '2.1'
SUPPORT_REC_MODEL = ['CRNN'] SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/") BASE_DIR = os.path.expanduser("~/.paddleocr/")
@ -315,14 +316,13 @@ class PaddleOCR(predict_system.TextSystem):
# init model dir # init model dir
if postprocess_params.det_model_dir is None: if postprocess_params.det_model_dir is None:
postprocess_params.det_model_dir = os.path.join( postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION,
BASE_DIR, '{}/det/{}'.format(VERSION, det_lang)) 'det', det_lang)
if postprocess_params.rec_model_dir is None: if postprocess_params.rec_model_dir is None:
postprocess_params.rec_model_dir = os.path.join( postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
BASE_DIR, '{}/rec/{}'.format(VERSION, lang)) 'rec', lang)
if postprocess_params.cls_model_dir is None: if postprocess_params.cls_model_dir is None:
postprocess_params.cls_model_dir = os.path.join( postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
BASE_DIR, '{}/cls'.format(VERSION))
print(postprocess_params) print(postprocess_params)
# download model # download model
maybe_download(postprocess_params.det_model_dir, maybe_download(postprocess_params.det_model_dir,

View File

@ -96,7 +96,7 @@ class BaseRecLabelEncode(object):
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean', 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
'mr', 'ne' 'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari'
] ]
assert character_type in support_character_type, "Only {} are supported now but get {}".format( assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type) support_character_type, character_type)
@ -187,29 +187,31 @@ class CTCLabelEncode(BaseRecLabelEncode):
return dict_character return dict_character
class E2ELabelEncode(BaseRecLabelEncode): class E2ELabelEncode(object):
def __init__(self, def __init__(self, **kwargs):
max_text_length, pass
character_dict_path=None,
character_type='EN',
use_space_char=False,
**kwargs):
super(E2ELabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
self.pad_num = len(self.dict) # the length to pad
def __call__(self, data): def __call__(self, data):
texts = data['strs'] import json
temp_texts = [] label = data['label']
for text in texts: label = json.loads(label)
text = text.lower() nBox = len(label)
text = self.encode(text) boxes, txts, txt_tags = [], [], []
if text is None: for bno in range(0, nBox):
return None box = label[bno]['points']
text = text + [self.pad_num] * (self.max_text_len - len(text)) txt = label[bno]['transcription']
temp_texts.append(text) boxes.append(box)
data['strs'] = np.array(temp_texts) txts.append(txt)
if txt in ['*', '###']:
txt_tags.append(True)
else:
txt_tags.append(False)
boxes = np.array(boxes, dtype=np.float32)
txt_tags = np.array(txt_tags, dtype=np.bool)
data['polys'] = boxes
data['texts'] = txts
data['ignore_tags'] = txt_tags
return data return data

View File

@ -88,7 +88,7 @@ class PGProcessTrain(object):
return min_area_quad return min_area_quad
def check_and_validate_polys(self, polys, tags, xxx_todo_changeme): def check_and_validate_polys(self, polys, tags, im_size):
""" """
check so that the text poly is in the same direction, check so that the text poly is in the same direction,
and also filter some invalid polygons and also filter some invalid polygons
@ -96,7 +96,7 @@ class PGProcessTrain(object):
:param tags: :param tags:
:return: :return:
""" """
(h, w) = xxx_todo_changeme (h, w) = im_size
if polys.shape[0] == 0: if polys.shape[0] == 0:
return polys, np.array([]), np.array([]) return polys, np.array([]), np.array([])
polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1) polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
@ -750,8 +750,8 @@ class PGProcessTrain(object):
input_size = 512 input_size = 512
im = data['image'] im = data['image']
text_polys = data['polys'] text_polys = data['polys']
text_tags = data['tags'] text_tags = data['ignore_tags']
text_strs = data['strs'] text_strs = data['texts']
h, w, _ = im.shape h, w, _ = im.shape
text_polys, text_tags, hv_tags = self.check_and_validate_polys( text_polys, text_tags, hv_tags = self.check_and_validate_polys(
text_polys, text_tags, (h, w)) text_polys, text_tags, (h, w))

View File

@ -29,20 +29,20 @@ class PGDataSet(Dataset):
dataset_config = config[mode]['dataset'] dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader'] loader_config = config[mode]['loader']
self.delimiter = dataset_config.get('delimiter', '\t')
label_file_list = dataset_config.pop('label_file_list') label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list) data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0]) ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)): if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num) ratio_list = [float(ratio_list)] * int(data_source_num)
self.data_format = dataset_config.get('data_format', 'icdar')
assert len( assert len(
ratio_list ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list." ) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle'] self.do_shuffle = loader_config['shuffle']
logger.info("Initialize indexs of datasets:%s" % label_file_list) logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list, self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_format)
self.data_idx_order_list = list(range(len(self.data_lines))) self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train": if mode.lower() == "train":
self.shuffle_data_random() self.shuffle_data_random()
@ -55,108 +55,40 @@ class PGDataSet(Dataset):
random.shuffle(self.data_lines) random.shuffle(self.data_lines)
return return
def extract_polys(self, poly_txt_path): def get_image_info_list(self, file_list, ratio_list):
"""
Read text_polys, txt_tags, txts from give txt file.
"""
text_polys, txt_tags, txts = [], [], []
with open(poly_txt_path) as f:
for line in f.readlines():
poly_str, txt = line.strip().split('\t')
poly = list(map(float, poly_str.split(',')))
text_polys.append(
np.array(
poly, dtype=np.float32).reshape(-1, 2))
txts.append(txt)
txt_tags.append(txt == '###')
return np.array(list(map(np.array, text_polys))), \
np.array(txt_tags, dtype=np.bool), txts
def extract_info_textnet(self, im_fn, img_dir=''):
"""
Extract information from line in textnet format.
"""
info_list = im_fn.split('\t')
img_path = ''
for ext in [
'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'JPG'
]:
if os.path.exists(os.path.join(img_dir, info_list[0] + "." + ext)):
img_path = os.path.join(img_dir, info_list[0] + "." + ext)
break
if img_path == '':
print('Image {0} NOT found in {1}, and it will be ignored.'.format(
info_list[0], img_dir))
nBox = (len(info_list) - 1) // 9
wordBBs, txts, txt_tags = [], [], []
for n in range(0, nBox):
wordBB = list(map(float, info_list[n * 9 + 1:(n + 1) * 9]))
txt = info_list[(n + 1) * 9]
wordBBs.append([[wordBB[0], wordBB[1]], [wordBB[2], wordBB[3]],
[wordBB[4], wordBB[5]], [wordBB[6], wordBB[7]]])
txts.append(txt)
if txt == '###':
txt_tags.append(True)
else:
txt_tags.append(False)
return img_path, np.array(wordBBs, dtype=np.float32), txt_tags, txts
def get_image_info_list(self, file_list, ratio_list, data_format='textnet'):
if isinstance(file_list, str): if isinstance(file_list, str):
file_list = [file_list] file_list = [file_list]
data_lines = [] data_lines = []
for idx, data_source in enumerate(file_list): for idx, file in enumerate(file_list):
image_files = [] with open(file, "rb") as f:
if data_format == 'icdar': lines = f.readlines()
image_files = [(data_source, x) for x in if self.mode == "train" or ratio_list[idx] < 1.0:
os.listdir(os.path.join(data_source, 'rgb'))
if x.split('.')[-1] in [
'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif',
'tiff', 'gif', 'JPG'
]]
elif data_format == 'textnet':
with open(data_source) as f:
image_files = [(data_source, x.strip())
for x in f.readlines()]
else:
print("Unrecognized data format...")
exit(-1)
random.seed(self.seed) random.seed(self.seed)
image_files = random.sample( lines = random.sample(lines,
image_files, round(len(image_files) * ratio_list[idx])) round(len(lines) * ratio_list[idx]))
data_lines.extend(image_files) data_lines.extend(lines)
return data_lines return data_lines
def __getitem__(self, idx): def __getitem__(self, idx):
file_idx = self.data_idx_order_list[idx] file_idx = self.data_idx_order_list[idx]
data_path, data_line = self.data_lines[file_idx] data_line = self.data_lines[file_idx]
try: try:
if self.data_format == 'icdar': data_line = data_line.decode('utf-8')
im_path = os.path.join(data_path, 'rgb', data_line) substr = data_line.strip("\n").split(self.delimiter)
poly_path = os.path.join(data_path, 'poly', file_name = substr[0]
data_line.split('.')[0] + '.txt') label = substr[1]
text_polys, text_tags, text_strs = self.extract_polys(poly_path) img_path = os.path.join(self.data_dir, file_name)
if self.mode.lower() == 'eval':
img_id = int(data_line.split(".")[0][7:])
else: else:
image_dir = os.path.join(os.path.dirname(data_path), 'image') img_id = 0
im_path, text_polys, text_tags, text_strs = self.extract_info_textnet( data = {'img_path': img_path, 'label': label, 'img_id': img_id}
data_line, image_dir) if not os.path.exists(img_path):
img_id = int(data_line.split(".")[0][3:]) raise Exception("{} does not exist!".format(img_path))
data = {
'img_path': im_path,
'polys': text_polys,
'tags': text_tags,
'strs': text_strs,
'img_id': img_id
}
with open(data['img_path'], 'rb') as f: with open(data['img_path'], 'rb') as f:
img = f.read() img = f.read()
data['image'] = img data['image'] = img
outs = transform(data, self.ops) outs = transform(data, self.ops)
except Exception as e: except Exception as e:
self.logger.error( self.logger.error(
"When parsing line {}, error happened with msg: {}".format( "When parsing line {}, error happened with msg: {}".format(

View File

@ -35,11 +35,11 @@ class E2EMetric(object):
self.reset() self.reset()
def __call__(self, preds, batch, **kwargs): def __call__(self, preds, batch, **kwargs):
img_id = batch[5][0] img_id = batch[2][0]
e2e_info_list = [{ e2e_info_list = [{
'points': det_polyon, 'points': det_polyon,
'text': pred_str 'texts': pred_str
} for det_polyon, pred_str in zip(preds['points'], preds['strs'])] } for det_polyon, pred_str in zip(preds['points'], preds['texts'])]
result = get_socre(self.gt_mat_dir, img_id, e2e_info_list) result = get_socre(self.gt_mat_dir, img_id, e2e_info_list)
self.results.append(result) self.results.append(result)

View File

@ -28,7 +28,7 @@ class BaseRecLabelDecode(object):
'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean', 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
'ne', 'EN' 'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari'
] ]
assert character_type in support_character_type, "Only {} are supported now but get {}".format( assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type) support_character_type, character_type)

View File

@ -26,7 +26,7 @@ def get_socre(gt_dir, img_id, pred_dict):
n = len(pred_dict) n = len(pred_dict)
for i in range(n): for i in range(n):
points = pred_dict[i]['points'] points = pred_dict[i]['points']
text = pred_dict[i]['text'] text = pred_dict[i]['texts']
point = ",".join(map(str, points.reshape(-1, ))) point = ",".join(map(str, points.reshape(-1, )))
det.append([point, text]) det.append([point, text])
return det return det

View File

@ -21,6 +21,7 @@ import math
import numpy as np import numpy as np
from itertools import groupby from itertools import groupby
from cv2.ximgproc import thinning as thin
from skimage.morphology._skeletonize import thin from skimage.morphology._skeletonize import thin

View File

@ -64,7 +64,7 @@ class PGNet_PostProcess(object):
src_w, src_h, self.valid_set) src_w, src_h, self.valid_set)
data = { data = {
'points': poly_list, 'points': poly_list,
'strs': keep_str_list, 'texts': keep_str_list,
} }
return data return data
@ -176,6 +176,6 @@ class PGNet_PostProcess(object):
exit(-1) exit(-1)
data = { data = {
'points': poly_list, 'points': poly_list,
'strs': keep_str_list, 'texts': keep_str_list,
} }
return data return data

View File

@ -122,7 +122,7 @@ class TextE2E(object):
else: else:
raise NotImplementedError raise NotImplementedError
post_result = self.postprocess_op(preds, shape_list) post_result = self.postprocess_op(preds, shape_list)
points, strs = post_result['points'], post_result['strs'] points, strs = post_result['points'], post_result['texts']
dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape) dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
elapse = time.time() - starttime elapse = time.time() - starttime
return dt_boxes, strs, elapse return dt_boxes, strs, elapse

View File

@ -103,7 +103,7 @@ def main():
images = paddle.to_tensor(images) images = paddle.to_tensor(images)
preds = model(images) preds = model(images)
post_result = post_process_class(preds, shape_list) post_result = post_process_class(preds, shape_list)
points, strs = post_result['points'], post_result['strs'] points, strs = post_result['points'], post_result['texts']
# write resule # write resule
dt_boxes_json = [] dt_boxes_json = []
for poly, str in zip(points, strs): for poly, str in zip(points, strs):