Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into dygraph_rc
|
@ -24,6 +24,7 @@ import sys
|
|||
from functools import partial
|
||||
from collections import defaultdict
|
||||
import json
|
||||
import cv2
|
||||
|
||||
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
|
@ -1242,10 +1243,13 @@ class MainWindow(QMainWindow, WindowMixin):
|
|||
# if unicodeFilePath in self.mImgList:
|
||||
|
||||
if unicodeFilePath and os.path.exists(unicodeFilePath):
|
||||
self.imageData = read(unicodeFilePath, None)
|
||||
self.canvas.verified = False
|
||||
|
||||
image = QImage.fromData(self.imageData)
|
||||
cvimg = cv2.imdecode(np.fromfile(unicodeFilePath, dtype=np.uint8), 1)
|
||||
height, width, depth = cvimg.shape
|
||||
cvimg = cv2.cvtColor(cvimg, cv2.COLOR_BGR2RGB)
|
||||
image = QImage(cvimg.data, width, height, width * depth, QImage.Format_RGB888)
|
||||
|
||||
if image.isNull():
|
||||
self.errorMessage(u'Error opening file',
|
||||
u"<p>Make sure <i>%s</i> is a valid image file." % unicodeFilePath)
|
||||
|
|
|
@ -7,6 +7,8 @@ except ImportError:
|
|||
from PyQt4.QtCore import *
|
||||
|
||||
import json
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from libs.utils import newIcon
|
||||
|
||||
|
@ -34,11 +36,16 @@ class Worker(QThread):
|
|||
if self.handle == 0:
|
||||
self.listValue.emit(Imgpath)
|
||||
if self.model == 'paddle':
|
||||
self.result_dic = self.ocr.ocr(Imgpath, cls=True, det=True)
|
||||
h, w, _ = cv2.imdecode(np.fromfile(Imgpath, dtype=np.uint8), 1).shape
|
||||
if h > 32 and w > 32:
|
||||
self.result_dic = self.ocr.ocr(Imgpath, cls=True, det=True)
|
||||
else:
|
||||
print('The size of', Imgpath, 'is too small to be recognised')
|
||||
self.result_dic = None
|
||||
|
||||
# 结果保存
|
||||
if self.result_dic is None or len(self.result_dic) == 0:
|
||||
print('Can not recognise file is : ', Imgpath)
|
||||
print('Can not recognise file', Imgpath)
|
||||
pass
|
||||
else:
|
||||
strs = ''
|
||||
|
|
|
@ -0,0 +1,152 @@
|
|||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import yaml
|
||||
from argparse import ArgumentParser, RawDescriptionHelpFormatter
|
||||
import os.path
|
||||
import logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
support_list = {
|
||||
'it':'italian', 'xi':'spanish', 'pu':'portuguese', 'ru':'russian', 'ar':'arabic',
|
||||
'ta':'tamil', 'ug':'uyghur', 'fa':'persian', 'ur':'urdu', 'rs':'serbian latin',
|
||||
'oc':'occitan', 'rsc':'serbian cyrillic', 'bg':'bulgarian', 'uk':'ukranian', 'be':'belarusian',
|
||||
'te':'telugu', 'ka':'kannada', 'chinese_cht':'chinese tradition','hi':'hindi','mr':'marathi',
|
||||
'ne':'nepali',
|
||||
}
|
||||
assert(
|
||||
os.path.isfile("./rec_multi_language_lite_train.yml")
|
||||
),"Loss basic configuration file rec_multi_language_lite_train.yml.\
|
||||
You can download it from \
|
||||
https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/configs/rec/multi_language/"
|
||||
|
||||
global_config = yaml.load(open("./rec_multi_language_lite_train.yml", 'rb'), Loader=yaml.Loader)
|
||||
project_path = os.path.abspath(os.path.join(os.getcwd(), "../../../"))
|
||||
|
||||
class ArgsParser(ArgumentParser):
|
||||
def __init__(self):
|
||||
super(ArgsParser, self).__init__(
|
||||
formatter_class=RawDescriptionHelpFormatter)
|
||||
self.add_argument(
|
||||
"-o", "--opt", nargs='+', help="set configuration options")
|
||||
self.add_argument(
|
||||
"-l", "--language", nargs='+', help="set language type, support {}".format(support_list))
|
||||
self.add_argument(
|
||||
"--train",type=str,help="you can use this command to change the train dataset default path")
|
||||
self.add_argument(
|
||||
"--val",type=str,help="you can use this command to change the eval dataset default path")
|
||||
self.add_argument(
|
||||
"--dict",type=str,help="you can use this command to change the dictionary default path")
|
||||
self.add_argument(
|
||||
"--data_dir",type=str,help="you can use this command to change the dataset default root path")
|
||||
|
||||
def parse_args(self, argv=None):
|
||||
args = super(ArgsParser, self).parse_args(argv)
|
||||
args.opt = self._parse_opt(args.opt)
|
||||
args.language = self._set_language(args.language)
|
||||
return args
|
||||
|
||||
def _parse_opt(self, opts):
|
||||
config = {}
|
||||
if not opts:
|
||||
return config
|
||||
for s in opts:
|
||||
s = s.strip()
|
||||
k, v = s.split('=')
|
||||
config[k] = yaml.load(v, Loader=yaml.Loader)
|
||||
return config
|
||||
|
||||
def _set_language(self, type):
|
||||
assert(type),"please use -l or --language to choose language type"
|
||||
assert(
|
||||
type[0] in support_list.keys()
|
||||
),"the sub_keys(-l or --language) can only be one of support list: \n{},\nbut get: {}, " \
|
||||
"please check your running command".format(support_list, type)
|
||||
global_config['Global']['character_dict_path'] = 'ppocr/utils/dict/{}_dict.txt'.format(type[0])
|
||||
global_config['Global']['save_model_dir'] = './output/rec_{}_lite'.format(type[0])
|
||||
global_config['Train']['dataset']['label_file_list'] = ["train_data/{}_train.txt".format(type[0])]
|
||||
global_config['Eval']['dataset']['label_file_list'] = ["train_data/{}_val.txt".format(type[0])]
|
||||
global_config['Global']['character_type'] = type[0]
|
||||
assert(
|
||||
os.path.isfile(os.path.join(project_path,global_config['Global']['character_dict_path']))
|
||||
),"Loss default dictionary file {}_dict.txt.You can download it from \
|
||||
https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/ppocr/utils/dict/".format(type[0])
|
||||
return type[0]
|
||||
|
||||
|
||||
def merge_config(config):
|
||||
"""
|
||||
Merge config into global config.
|
||||
Args:
|
||||
config (dict): Config to be merged.
|
||||
Returns: global config
|
||||
"""
|
||||
for key, value in config.items():
|
||||
if "." not in key:
|
||||
if isinstance(value, dict) and key in global_config:
|
||||
global_config[key].update(value)
|
||||
else:
|
||||
global_config[key] = value
|
||||
else:
|
||||
sub_keys = key.split('.')
|
||||
assert (
|
||||
sub_keys[0] in global_config
|
||||
), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
|
||||
global_config.keys(), sub_keys[0])
|
||||
cur = global_config[sub_keys[0]]
|
||||
for idx, sub_key in enumerate(sub_keys[1:]):
|
||||
if idx == len(sub_keys) - 2:
|
||||
cur[sub_key] = value
|
||||
else:
|
||||
cur = cur[sub_key]
|
||||
|
||||
def loss_file(path):
|
||||
assert(
|
||||
os.path.exists(path)
|
||||
),"There is no such file:{},Please do not forget to put in the specified file".format(path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
FLAGS = ArgsParser().parse_args()
|
||||
merge_config(FLAGS.opt)
|
||||
save_file_path = 'rec_{}_lite_train.yml'.format(FLAGS.language)
|
||||
if os.path.isfile(save_file_path):
|
||||
os.remove(save_file_path)
|
||||
|
||||
if FLAGS.train:
|
||||
global_config['Train']['dataset']['label_file_list'] = [FLAGS.train]
|
||||
train_label_path = os.path.join(project_path,FLAGS.train)
|
||||
loss_file(train_label_path)
|
||||
if FLAGS.val:
|
||||
global_config['Eval']['dataset']['label_file_list'] = [FLAGS.val]
|
||||
eval_label_path = os.path.join(project_path,FLAGS.val)
|
||||
loss_file(Eval_label_path)
|
||||
if FLAGS.dict:
|
||||
global_config['Global']['character_dict_path'] = FLAGS.dict
|
||||
dict_path = os.path.join(project_path,FLAGS.dict)
|
||||
loss_file(dict_path)
|
||||
if FLAGS.data_dir:
|
||||
global_config['Eval']['dataset']['data_dir'] = FLAGS.data_dir
|
||||
global_config['Train']['dataset']['data_dir'] = FLAGS.data_dir
|
||||
data_dir = os.path.join(project_path,FLAGS.data_dir)
|
||||
loss_file(data_dir)
|
||||
|
||||
with open(save_file_path, 'w') as f:
|
||||
yaml.dump(dict(global_config), f, default_flow_style=False, sort_keys=False)
|
||||
logging.info("Project path is :{}".format(project_path))
|
||||
logging.info("Train list path set to :{}".format(global_config['Train']['dataset']['label_file_list'][0]))
|
||||
logging.info("Eval list path set to :{}".format(global_config['Eval']['dataset']['label_file_list'][0]))
|
||||
logging.info("Dataset root path set to :{}".format(global_config['Eval']['dataset']['data_dir']))
|
||||
logging.info("Dict path set to :{}".format(global_config['Global']['character_dict_path']))
|
||||
logging.info("Config file set to :configs/rec/multi_language/{}".format(save_file_path))
|
|
@ -0,0 +1,103 @@
|
|||
Global:
|
||||
use_gpu: True
|
||||
epoch_num: 500
|
||||
log_smooth_window: 20
|
||||
print_batch_step: 10
|
||||
save_model_dir: ./output/rec_multi_language_lite
|
||||
save_epoch_step: 3
|
||||
# evaluation is run every 5000 iterations after the 4000th iteration
|
||||
eval_batch_step: [0, 2000]
|
||||
# if pretrained_model is saved in static mode, load_static_weights must set to True
|
||||
cal_metric_during_train: True
|
||||
pretrained_model:
|
||||
checkpoints:
|
||||
save_inference_dir:
|
||||
use_visualdl: False
|
||||
infer_img:
|
||||
# for data or label process
|
||||
character_dict_path:
|
||||
# Set the language of training, if set, select the default dictionary file
|
||||
character_type:
|
||||
max_text_length: 25
|
||||
infer_mode: False
|
||||
use_space_char: True
|
||||
|
||||
|
||||
Optimizer:
|
||||
name: Adam
|
||||
beta1: 0.9
|
||||
beta2: 0.999
|
||||
lr:
|
||||
name: Cosine
|
||||
learning_rate: 0.001
|
||||
regularizer:
|
||||
name: 'L2'
|
||||
factor: 0.00001
|
||||
|
||||
Architecture:
|
||||
model_type: rec
|
||||
algorithm: CRNN
|
||||
Transform:
|
||||
Backbone:
|
||||
name: MobileNetV3
|
||||
scale: 0.5
|
||||
model_name: small
|
||||
small_stride: [1, 2, 2, 2]
|
||||
Neck:
|
||||
name: SequenceEncoder
|
||||
encoder_type: rnn
|
||||
hidden_size: 48
|
||||
Head:
|
||||
name: CTCHead
|
||||
fc_decay: 0.00001
|
||||
|
||||
Loss:
|
||||
name: CTCLoss
|
||||
|
||||
PostProcess:
|
||||
name: CTCLabelDecode
|
||||
|
||||
Metric:
|
||||
name: RecMetric
|
||||
main_indicator: acc
|
||||
|
||||
Train:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: train_data/
|
||||
label_file_list: ["./train_data/train_list.txt"]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- RecAug:
|
||||
- CTCLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 320]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: True
|
||||
batch_size_per_card: 256
|
||||
drop_last: True
|
||||
num_workers: 8
|
||||
|
||||
Eval:
|
||||
dataset:
|
||||
name: SimpleDataSet
|
||||
data_dir: train_data/
|
||||
label_file_list: ["./train_data/val_list.txt"]
|
||||
transforms:
|
||||
- DecodeImage: # load image
|
||||
img_mode: BGR
|
||||
channel_first: False
|
||||
- CTCLabelEncode: # Class handling label
|
||||
- RecResizeImg:
|
||||
image_shape: [3, 32, 320]
|
||||
- KeepKeys:
|
||||
keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
|
||||
loader:
|
||||
shuffle: False
|
||||
drop_last: False
|
||||
batch_size_per_card: 256
|
||||
num_workers: 8
|
|
@ -14,11 +14,10 @@ PaddleOCR开源的文本检测算法列表:
|
|||
- [x] SAST([paper](https://arxiv.org/abs/1908.05498))[4]
|
||||
|
||||
在ICDAR2015文本检测公开数据集上,算法效果如下:
|
||||
|
||||
|模型|骨干网络|precision|recall|Hmean|下载链接|
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
|EAST|ResNet50_vd|88.76%|81.36%|84.90%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|
||||
|EAST|MobileNetV3|78.24%|79.15%|78.69%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)|
|
||||
|EAST|ResNet50_vd|85.80%|86.71%|86.25%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|
||||
|EAST|MobileNetV3|79.42%|80.64%|80.03%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)|
|
||||
|DB|ResNet50_vd|86.41%|78.72%|82.38%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|
||||
|DB|MobileNetV3|77.29%|73.08%|75.12%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|
||||
|SAST|ResNet50_vd|91.39%|83.77%|87.42%|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
## OCR模型列表(V2.0,2020年12月12日更新)
|
||||
## OCR模型列表(V2.0,2021年1月20日更新)
|
||||
**说明** :2.0版模型和[1.1版模型](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_ch/models_list.md)的主要区别在于动态图训练vs.静态图训练,模型性能上无明显差距。
|
||||
|
||||
- [一、文本检测模型](#文本检测模型)
|
||||
|
@ -52,12 +52,70 @@ PaddleOCR提供的可下载模型包括`推理模型`、`训练模型`、`预训
|
|||
<a name="多语言识别模型"></a>
|
||||
#### 3. 多语言识别模型(更多语言持续更新中...)
|
||||
|
||||
**说明:** 新增的多语言模型的配置文件通过代码方式生成,您可以通过`--help`参数查看当前PaddleOCR支持生成哪些多语言的配置文件:
|
||||
```bash
|
||||
# 该代码需要在指定目录运行
|
||||
cd {your/path/}PaddleOCR/configs/rec/multi_language/
|
||||
python3 generate_multi_language_configs.py --help
|
||||
```
|
||||
下面以生成意大利语配置文件为例:
|
||||
##### 1. 生成意大利语配置文件测试现有模型
|
||||
|
||||
如果您仅仅想用配置文件测试PaddleOCR提供的多语言模型可以通过下面命令生成默认的配置文件,使用PaddleOCR提供的小语种字典进行预测。
|
||||
```bash
|
||||
# 该代码需要在指定目录运行
|
||||
cd {your/path/}PaddleOCR/configs/rec/multi_language/
|
||||
# 通过-l或者--language参数设置需要生成的语种的配置文件,该命令会将默认参数写入配置文件
|
||||
python3 generate_multi_language_configs.py -l it
|
||||
```
|
||||
##### 2. 生成意大利语配置文件训练自己的数据
|
||||
如果您想训练自己的小语种模型,可以准备好训练集文件、验证集文件、字典文件和训练数据路径,这里假设准备的意大利语的训练集、验证集、字典和训练数据路径为:
|
||||
- 训练集:{your/path/}PaddleOCR/train_data/train_list.txt
|
||||
- 验证集:{your/path/}PaddleOCR/train_data/val_list.txt
|
||||
- 使用PaddleOCR提供的默认字典:{your/path/}PaddleOCR/ppocr/utils/dict/it_dict.txt
|
||||
- 训练数据路径:{your/path/}PaddleOCR/train_data
|
||||
|
||||
使用以下命令生成配置文件:
|
||||
```bash
|
||||
# 该代码需要在指定目录运行
|
||||
cd {your/path/}PaddleOCR/configs/rec/multi_language/
|
||||
# -l或者--language字段是必须的
|
||||
# --train修改训练集,--val修改验证集,--data_dir修改数据集目录,-o修改对应默认参数
|
||||
# --dict命令改变字典路径,示例使用默认字典路径则该参数可不填
|
||||
python3 generate_multi_language_configs.py -l it \
|
||||
--train train_data/train_list.txt \
|
||||
--val train_data/val_list.txt \
|
||||
--data_dir train_data \
|
||||
-o Global.use_gpu=False
|
||||
```
|
||||
|
||||
|模型名称|模型简介|配置文件|推理模型大小|下载地址|
|
||||
| --- | --- | --- | --- | --- |
|
||||
| french_mobile_v2.0_rec |法文识别|[rec_french_lite_train.yml](../../configs/rec/multi_language/rec_french_lite_train.yml)|2.65M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_train.tar) |
|
||||
| german_mobile_v2.0_rec |德文识别|[rec_german_lite_train.yml](../../configs/rec/multi_language/rec_german_lite_train.yml)|2.65M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_train.tar) |
|
||||
| korean_mobile_v2.0_rec |韩文识别|[rec_korean_lite_train.yml](../../configs/rec/multi_language/rec_korean_lite_train.yml)|3.9M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_train.tar) |
|
||||
| japan_mobile_v2.0_rec |日文识别|[rec_japan_lite_train.yml](../../configs/rec/multi_language/rec_japan_lite_train.yml)|4.23M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_train.tar) |
|
||||
| it_mobile_v2.0_rec |意大利文识别|rec_it_lite_train.yml|2.53M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_train.tar) |
|
||||
| xi_mobile_v2.0_rec |西班牙文识别|rec_xi_lite_train.yml|2.53M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/xi_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/xi_mobile_v2.0_rec_train.tar) |
|
||||
| pu_mobile_v2.0_rec |葡萄牙文识别|rec_pu_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/pu_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/pu_mobile_v2.0_rec_train.tar) |
|
||||
| ru_mobile_v2.0_rec |俄罗斯文识别|rec_ru_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ru_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ru_mobile_v2.0_rec_train.tar) |
|
||||
| ar_mobile_v2.0_rec |阿拉伯文识别|rec_ar_lite_train.yml|2.53M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ar_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ar_mobile_v2.0_rec_train.tar) |
|
||||
| hi_mobile_v2.0_rec |印地文识别|rec_hi_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/hi_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/hi_mobile_v2.0_rec_train.tar) |
|
||||
| chinese_cht_mobile_v2.0_rec |中文繁体识别|rec_chinese_cht_lite_train.yml|5.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_train.tar) |
|
||||
| ug_mobile_v2.0_rec |维吾尔文识别|rec_ug_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ug_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ug_mobile_v2.0_rec_train.tar) |
|
||||
| fa_mobile_v2.0_rec |波斯文识别|rec_fa_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/fa_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/fa_mobile_v2.0_rec_train.tar) |
|
||||
| ur_mobile_v2.0_rec |乌尔都文识别|rec_ur_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ur_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ur_mobile_v2.0_rec_train.tar) |
|
||||
| rs_mobile_v2.0_rec |塞尔维亚文(latin)识别|rec_rs_lite_train.yml|2.53M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/rs_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/rs_mobile_v2.0_rec_train.tar) |
|
||||
| oc_mobile_v2.0_rec |欧西坦文识别|rec_oc_lite_train.yml|2.53M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/oc_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/oc_mobile_v2.0_rec_train.tar) |
|
||||
| mr_mobile_v2.0_rec |马拉地文识别|rec_mr_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/mr_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/mr_mobile_v2.0_rec_train.tar) |
|
||||
| ne_mobile_v2.0_rec |尼泊尔文识别|rec_ne_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ne_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ne_mobile_v2.0_rec_train.tar) |
|
||||
| rsc_mobile_v2.0_rec |塞尔维亚文(cyrillic)识别|rec_rsc_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/rsc_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/rsc_mobile_v2.0_rec_train.tar) |
|
||||
| bg_mobile_v2.0_rec |保加利亚文识别|rec_bg_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/bg_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/bg_mobile_v2.0_rec_train.tar) |
|
||||
| uk_mobile_v2.0_rec |乌克兰文识别|rec_uk_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/uk_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/uk_mobile_v2.0_rec_train.tar) |
|
||||
| be_mobile_v2.0_rec |白俄罗斯文识别|rec_be_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/be_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/be_mobile_v2.0_rec_train.tar) |
|
||||
| te_mobile_v2.0_rec |泰卢固文识别|rec_te_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_train.tar) |
|
||||
| ka_mobile_v2.0_rec |卡纳达文识别|rec_ka_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_train.tar) |
|
||||
| ta_mobile_v2.0_rec |泰米尔文识别|rec_ta_lite_train.yml|2.63M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_train.tar) |
|
||||
|
||||
|
||||
<a name="文本方向分类模型"></a>
|
||||
|
|
|
@ -19,8 +19,8 @@ On the ICDAR2015 dataset, the text detection result is as follows:
|
|||
|
||||
|Model|Backbone|precision|recall|Hmean|Download link|
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
|EAST|ResNet50_vd|88.76%|81.36%|84.90%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|
||||
|EAST|MobileNetV3|78.24%|79.15%|78.69%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)|
|
||||
|EAST|ResNet50_vd|85.80%|86.71%|86.25%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)|
|
||||
|EAST|MobileNetV3|79.42%|80.64%|80.03%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_east_v2.0_train.tar)|
|
||||
|DB|ResNet50_vd|86.41%|78.72%|82.38%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|
||||
|DB|MobileNetV3|77.29%|73.08%|75.12%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|
||||
|SAST|ResNet50_vd|91.39%|83.77%|87.42%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
## OCR model list(V2.0, updated on 2020.12.12)
|
||||
## OCR model list(V2.0, updated on 2021.1.20)
|
||||
**Note** : Compared with [models 1.1](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/models_list_en.md), which are trained with static graph programming paradigm, models 2.0 are the dynamic graph trained version and achieve close performance.
|
||||
|
||||
- [1. Text Detection Model](#Detection)
|
||||
|
@ -51,12 +51,73 @@ The downloadable models provided by PaddleOCR include `inference model`, `traine
|
|||
<a name="Multilingual"></a>
|
||||
#### Multilingual Recognition Model(Updating...)
|
||||
|
||||
**Note:** The configuration file of the new multi language model is generated by code. You can use the `--help` parameter to check which multi language are supported by current PaddleOCR.
|
||||
|
||||
```bash
|
||||
# The code needs to run in the specified directory
|
||||
cd {your/path/}PaddleOCR/configs/rec/multi_language/
|
||||
python3 generate_multi_language_configs.py --help
|
||||
```
|
||||
|
||||
Take the Italian configuration file as an example:
|
||||
##### 1.Generate Italian configuration file to test the model provided
|
||||
you can generate the default configuration file through the following command, and use the default language dictionary provided by paddleocr for prediction.
|
||||
```bash
|
||||
# The code needs to run in the specified directory
|
||||
cd {your/path/}PaddleOCR/configs/rec/multi_language/
|
||||
# Set the required language configuration file through -l or --language parameter
|
||||
# This command will write the default parameter to the configuration file.
|
||||
python3 generate_multi_language_configs.py -l it
|
||||
```
|
||||
##### 2. Generate Italian configuration file to train your own data
|
||||
If you want to train your own model, you can prepare the training set file, verification set file, dictionary file and training data path. Here we assume that the Italian training set, verification set, dictionary and training data path are:
|
||||
- Training set:{your/path/}PaddleOCR/train_data/train_list.txt
|
||||
- Validation set: {your/path/}PaddleOCR/train_data/val_list.txt
|
||||
- Use the default dictionary provided by paddleocr:{your/path/}PaddleOCR/ppocr/utils/dict/it_dict.txt
|
||||
- Training data path:{your/path/}PaddleOCR/train_data
|
||||
```bash
|
||||
# The code needs to run in the specified directory
|
||||
cd {your/path/}PaddleOCR/configs/rec/multi_language/
|
||||
# The -l or --language parameter is required
|
||||
# --train modify train_list path
|
||||
# --val modify eval_list path
|
||||
# --data_dir modify data dir
|
||||
# -o modify default parameters
|
||||
# --dict Change the dictionary path. The example uses the default dictionary path, so that this parameter can be empty.
|
||||
python3 generate_multi_language_configs.py -l it \
|
||||
--train {path/to/train_list} \
|
||||
--val {path/to/val_list} \
|
||||
--data_dir {path/to/data_dir} \
|
||||
-o Global.use_gpu=False
|
||||
```
|
||||
|model name|description|config|model size|download|
|
||||
| --- | --- | --- | --- | --- |
|
||||
| french_mobile_v2.0_rec |Lightweight model for French recognition|[rec_french_lite_train.yml](../../configs/rec/multi_language/rec_french_lite_train.yml)|2.65M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_train.tar) |
|
||||
| german_mobile_v2.0_rec |Lightweight model for French recognition|[rec_german_lite_train.yml](../../configs/rec/multi_language/rec_german_lite_train.yml)|2.65M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_train.tar) |
|
||||
| korean_mobile_v2.0_rec |Lightweight model for Korean recognition|[rec_korean_lite_train.yml](../../configs/rec/multi_language/rec_korean_lite_train.yml)|3.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_train.tar) |
|
||||
| japan_mobile_v2.0_rec |Lightweight model for Japanese recognition|[rec_japan_lite_train.yml](../../configs/rec/multi_language/rec_japan_lite_train.yml)|4.23M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_train.tar) |
|
||||
| it_mobile_v2.0_rec |Lightweight model for Italian recognition|rec_it_lite_train.yml|2.53M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_train.tar) |
|
||||
| xi_mobile_v2.0_rec |Lightweight model for Spanish recognition|rec_xi_lite_train.yml|2.53M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/xi_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/xi_mobile_v2.0_rec_train.tar) |
|
||||
| pu_mobile_v2.0_rec |Lightweight model for Portuguese recognition|rec_pu_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/pu_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/pu_mobile_v2.0_rec_train.tar) |
|
||||
| ru_mobile_v2.0_rec |Lightweight model for Russia recognition|rec_ru_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ru_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ru_mobile_v2.0_rec_train.tar) |
|
||||
| ar_mobile_v2.0_rec |Lightweight model for Arabic recognition|rec_ar_lite_train.yml|2.53M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ar_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ar_mobile_v2.0_rec_train.tar) |
|
||||
| hi_mobile_v2.0_rec |Lightweight model for Hindi recognition|rec_hi_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/hi_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/hi_mobile_v2.0_rec_train.tar) |
|
||||
| chinese_cht_mobile_v2.0_rec |Lightweight model for chinese traditional recognition|rec_chinese_cht_lite_train.yml|5.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_train.tar) |
|
||||
| ug_mobile_v2.0_rec |Lightweight model for Uyghur recognition|rec_ug_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ug_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ug_mobile_v2.0_rec_train.tar) |
|
||||
| fa_mobile_v2.0_rec |Lightweight model for Persian recognition|rec_fa_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/fa_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/fa_mobile_v2.0_rec_train.tar) |
|
||||
| ur_mobile_v2.0_rec |Lightweight model for Urdu recognition|rec_ur_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ur_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ur_mobile_v2.0_rec_train.tar) |
|
||||
| rs_mobile_v2.0_rec |Lightweight model for Serbian(latin) recognition|rec_rs_lite_train.yml|2.53M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/rs_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/rs_mobile_v2.0_rec_train.tar) |
|
||||
| oc_mobile_v2.0_rec |Lightweight model for Occitan recognition|rec_oc_lite_train.yml|2.53M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/oc_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/oc_mobile_v2.0_rec_train.tar) |
|
||||
| mr_mobile_v2.0_rec |Lightweight model for Marathi recognition|rec_mr_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/mr_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/mr_mobile_v2.0_rec_train.tar) |
|
||||
| ne_mobile_v2.0_rec |Lightweight model for Nepali recognition|rec_ne_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ne_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ne_mobile_v2.0_rec_train.tar) |
|
||||
| rsc_mobile_v2.0_rec |Lightweight model for Serbian(cyrillic) recognition|rec_rsc_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/rsc_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/rsc_mobile_v2.0_rec_train.tar) |
|
||||
| bg_mobile_v2.0_rec |Lightweight model for Bulgarian recognition|rec_bg_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/bg_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/bg_mobile_v2.0_rec_train.tar) |
|
||||
| uk_mobile_v2.0_rec |Lightweight model for Ukranian recognition|rec_uk_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/uk_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/uk_mobile_v2.0_rec_train.tar) |
|
||||
| be_mobile_v2.0_rec |Lightweight model for Belarusian recognition|rec_be_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/be_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/be_mobile_v2.0_rec_train.tar) |
|
||||
| te_mobile_v2.0_rec |Lightweight model for Telugu recognition|rec_te_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_train.tar) |
|
||||
| ka_mobile_v2.0_rec |Lightweight model for Kannada recognition|rec_ka_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_train.tar) |
|
||||
| ta_mobile_v2.0_rec |Lightweight model for Tamil recognition|rec_ta_lite_train.yml|2.63M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_train.tar) |
|
||||
|
||||
|
||||
<a name="Angle"></a>
|
||||
### 3. Text Angle Classification Model
|
||||
|
|
After Width: | Height: | Size: 4.7 KiB |
After Width: | Height: | Size: 3.6 KiB |
After Width: | Height: | Size: 6.4 KiB |
After Width: | Height: | Size: 4.5 KiB |
After Width: | Height: | Size: 6.8 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 65 KiB |
After Width: | Height: | Size: 73 KiB |
After Width: | Height: | Size: 5.7 KiB |
After Width: | Height: | Size: 6.5 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 9.4 KiB |
After Width: | Height: | Size: 6.7 KiB |
After Width: | Height: | Size: 7.8 KiB |
After Width: | Height: | Size: 4.4 KiB |
After Width: | Height: | Size: 2.8 KiB |
After Width: | Height: | Size: 5.4 KiB |
After Width: | Height: | Size: 4.1 KiB |
After Width: | Height: | Size: 2.7 KiB |
After Width: | Height: | Size: 6.5 KiB |
After Width: | Height: | Size: 3.9 KiB |
After Width: | Height: | Size: 5.3 KiB |
After Width: | Height: | Size: 15 KiB |
After Width: | Height: | Size: 11 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 6.0 KiB |
After Width: | Height: | Size: 4.5 KiB |
After Width: | Height: | Size: 6.6 KiB |
After Width: | Height: | Size: 4.1 KiB |
After Width: | Height: | Size: 4.4 KiB |
After Width: | Height: | Size: 8.5 KiB |
After Width: | Height: | Size: 7.0 KiB |
After Width: | Height: | Size: 6.1 KiB |
After Width: | Height: | Size: 5.2 KiB |
After Width: | Height: | Size: 8.2 KiB |
After Width: | Height: | Size: 6.0 KiB |
After Width: | Height: | Size: 4.4 KiB |
After Width: | Height: | Size: 13 KiB |
After Width: | Height: | Size: 5.0 KiB |
After Width: | Height: | Size: 4.7 KiB |
After Width: | Height: | Size: 5.6 KiB |
After Width: | Height: | Size: 4.8 KiB |
BIN
doc/joinus.PNG
Before Width: | Height: | Size: 109 KiB After Width: | Height: | Size: 107 KiB |
|
@ -290,7 +290,9 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
image_file = img
|
||||
img, flag = check_and_read_gif(image_file)
|
||||
if not flag:
|
||||
img = cv2.imread(image_file)
|
||||
with open(image_file, 'rb') as f:
|
||||
np_arr = np.frombuffer(f.read(), dtype=np.uint8)
|
||||
img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
|
||||
if img is None:
|
||||
logger.error("error in loading image:{}".format(image_file))
|
||||
return None
|
||||
|
|
|
@ -24,11 +24,11 @@ __all__ = ['SASTProcessTrain']
|
|||
|
||||
class SASTProcessTrain(object):
|
||||
def __init__(self,
|
||||
image_shape = [512, 512],
|
||||
min_crop_size = 24,
|
||||
min_crop_side_ratio = 0.3,
|
||||
min_text_size = 10,
|
||||
max_text_size = 512,
|
||||
image_shape=[512, 512],
|
||||
min_crop_size=24,
|
||||
min_crop_side_ratio=0.3,
|
||||
min_text_size=10,
|
||||
max_text_size=512,
|
||||
**kwargs):
|
||||
self.input_size = image_shape[1]
|
||||
self.min_crop_size = min_crop_size
|
||||
|
@ -42,12 +42,10 @@ class SASTProcessTrain(object):
|
|||
:param poly:
|
||||
:return:
|
||||
"""
|
||||
edge = [
|
||||
(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
|
||||
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
|
||||
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
|
||||
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])
|
||||
]
|
||||
edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
|
||||
(poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
|
||||
(poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
|
||||
(poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
|
||||
return np.sum(edge) / 2.
|
||||
|
||||
def gen_quad_from_poly(self, poly):
|
||||
|
@ -57,7 +55,8 @@ class SASTProcessTrain(object):
|
|||
point_num = poly.shape[0]
|
||||
min_area_quad = np.zeros((4, 2), dtype=np.float32)
|
||||
if True:
|
||||
rect = cv2.minAreaRect(poly.astype(np.int32)) # (center (x,y), (width, height), angle of rotation)
|
||||
rect = cv2.minAreaRect(poly.astype(
|
||||
np.int32)) # (center (x,y), (width, height), angle of rotation)
|
||||
center_point = rect[0]
|
||||
box = np.array(cv2.boxPoints(rect))
|
||||
|
||||
|
@ -102,23 +101,33 @@ class SASTProcessTrain(object):
|
|||
if p_area > 0:
|
||||
if tag == False:
|
||||
print('poly in wrong direction')
|
||||
tag = True # reversed cases should be ignore
|
||||
poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1), :]
|
||||
tag = True # reversed cases should be ignore
|
||||
poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
|
||||
1), :]
|
||||
quad = quad[(0, 3, 2, 1), :]
|
||||
|
||||
len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] - quad[2])
|
||||
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2])
|
||||
len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
|
||||
quad[2])
|
||||
len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
|
||||
quad[2])
|
||||
hv_tag = 1
|
||||
|
||||
if len_w * 2.0 < len_h:
|
||||
|
||||
if len_w * 2.0 < len_h:
|
||||
hv_tag = 0
|
||||
|
||||
validated_polys.append(poly)
|
||||
validated_tags.append(tag)
|
||||
hv_tags.append(hv_tag)
|
||||
return np.array(validated_polys), np.array(validated_tags), np.array(hv_tags)
|
||||
return np.array(validated_polys), np.array(validated_tags), np.array(
|
||||
hv_tags)
|
||||
|
||||
def crop_area(self, im, polys, tags, hv_tags, crop_background=False, max_tries=25):
|
||||
def crop_area(self,
|
||||
im,
|
||||
polys,
|
||||
tags,
|
||||
hv_tags,
|
||||
crop_background=False,
|
||||
max_tries=25):
|
||||
"""
|
||||
make random crop from the input image
|
||||
:param im:
|
||||
|
@ -137,10 +146,10 @@ class SASTProcessTrain(object):
|
|||
poly = np.round(poly, decimals=0).astype(np.int32)
|
||||
minx = np.min(poly[:, 0])
|
||||
maxx = np.max(poly[:, 0])
|
||||
w_array[minx + pad_w: maxx + pad_w] = 1
|
||||
w_array[minx + pad_w:maxx + pad_w] = 1
|
||||
miny = np.min(poly[:, 1])
|
||||
maxy = np.max(poly[:, 1])
|
||||
h_array[miny + pad_h: maxy + pad_h] = 1
|
||||
h_array[miny + pad_h:maxy + pad_h] = 1
|
||||
# ensure the cropped area not across a text
|
||||
h_axis = np.where(h_array == 0)[0]
|
||||
w_axis = np.where(w_array == 0)[0]
|
||||
|
@ -166,17 +175,18 @@ class SASTProcessTrain(object):
|
|||
if polys.shape[0] != 0:
|
||||
poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
|
||||
& (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
|
||||
selected_polys = np.where(np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
||||
selected_polys = np.where(
|
||||
np.sum(poly_axis_in_area, axis=1) == 4)[0]
|
||||
else:
|
||||
selected_polys = []
|
||||
if len(selected_polys) == 0:
|
||||
# no text in this area
|
||||
if crop_background:
|
||||
return im[ymin : ymax + 1, xmin : xmax + 1, :], \
|
||||
polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
|
||||
polys[selected_polys], tags[selected_polys], hv_tags[selected_polys]
|
||||
else:
|
||||
continue
|
||||
im = im[ymin: ymax + 1, xmin: xmax + 1, :]
|
||||
im = im[ymin:ymax + 1, xmin:xmax + 1, :]
|
||||
polys = polys[selected_polys]
|
||||
tags = tags[selected_polys]
|
||||
hv_tags = hv_tags[selected_polys]
|
||||
|
@ -192,18 +202,28 @@ class SASTProcessTrain(object):
|
|||
width_list = []
|
||||
height_list = []
|
||||
for quad in poly_quads:
|
||||
quad_w = (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3])) / 2.0
|
||||
quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0
|
||||
quad_w = (np.linalg.norm(quad[0] - quad[1]) +
|
||||
np.linalg.norm(quad[2] - quad[3])) / 2.0
|
||||
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
|
||||
np.linalg.norm(quad[2] - quad[1])) / 2.0
|
||||
width_list.append(quad_w)
|
||||
height_list.append(quad_h)
|
||||
norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0)
|
||||
norm_width = max(sum(width_list) / (len(width_list) + 1e-6), 1.0)
|
||||
average_height = max(sum(height_list) / (len(height_list) + 1e-6), 1.0)
|
||||
|
||||
for quad in poly_quads:
|
||||
direct_vector_full = ((quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
|
||||
direct_vector = direct_vector_full / (np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
|
||||
direction_label = tuple(map(float, [direct_vector[0], direct_vector[1], 1.0 / (average_height + 1e-6)]))
|
||||
cv2.fillPoly(direction_map, quad.round().astype(np.int32)[np.newaxis, :, :], direction_label)
|
||||
direct_vector_full = (
|
||||
(quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
|
||||
direct_vector = direct_vector_full / (
|
||||
np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
|
||||
direction_label = tuple(
|
||||
map(float, [
|
||||
direct_vector[0], direct_vector[1], 1.0 / (average_height +
|
||||
1e-6)
|
||||
]))
|
||||
cv2.fillPoly(direction_map,
|
||||
quad.round().astype(np.int32)[np.newaxis, :, :],
|
||||
direction_label)
|
||||
return direction_map
|
||||
|
||||
def calculate_average_height(self, poly_quads):
|
||||
|
@ -211,13 +231,19 @@ class SASTProcessTrain(object):
|
|||
"""
|
||||
height_list = []
|
||||
for quad in poly_quads:
|
||||
quad_h = (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[2] - quad[1])) / 2.0
|
||||
quad_h = (np.linalg.norm(quad[0] - quad[3]) +
|
||||
np.linalg.norm(quad[2] - quad[1])) / 2.0
|
||||
height_list.append(quad_h)
|
||||
average_height = max(sum(height_list) / len(height_list), 1.0)
|
||||
return average_height
|
||||
|
||||
def generate_tcl_label(self, hw, polys, tags, ds_ratio,
|
||||
tcl_ratio=0.3, shrink_ratio_of_width=0.15):
|
||||
def generate_tcl_label(self,
|
||||
hw,
|
||||
polys,
|
||||
tags,
|
||||
ds_ratio,
|
||||
tcl_ratio=0.3,
|
||||
shrink_ratio_of_width=0.15):
|
||||
"""
|
||||
Generate polygon.
|
||||
"""
|
||||
|
@ -225,21 +251,30 @@ class SASTProcessTrain(object):
|
|||
h, w = int(h * ds_ratio), int(w * ds_ratio)
|
||||
polys = polys * ds_ratio
|
||||
|
||||
score_map = np.zeros((h, w,), dtype=np.float32)
|
||||
score_map = np.zeros(
|
||||
(
|
||||
h,
|
||||
w, ), dtype=np.float32)
|
||||
tbo_map = np.zeros((h, w, 5), dtype=np.float32)
|
||||
training_mask = np.ones((h, w,), dtype=np.float32)
|
||||
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape([1, 1, 3]).astype(np.float32)
|
||||
training_mask = np.ones(
|
||||
(
|
||||
h,
|
||||
w, ), dtype=np.float32)
|
||||
direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
|
||||
[1, 1, 3]).astype(np.float32)
|
||||
|
||||
for poly_idx, poly_tag in enumerate(zip(polys, tags)):
|
||||
poly = poly_tag[0]
|
||||
poly = poly_tag[0]
|
||||
tag = poly_tag[1]
|
||||
|
||||
# generate min_area_quad
|
||||
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
|
||||
min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
|
||||
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
|
||||
min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
|
||||
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
|
||||
min_area_quad_h = 0.5 * (
|
||||
np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
|
||||
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
|
||||
min_area_quad_w = 0.5 * (
|
||||
np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
|
||||
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
|
||||
|
||||
if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
|
||||
or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
|
||||
|
@ -247,25 +282,37 @@ class SASTProcessTrain(object):
|
|||
|
||||
if tag:
|
||||
# continue
|
||||
cv2.fillPoly(training_mask, poly.astype(np.int32)[np.newaxis, :, :], 0.15)
|
||||
cv2.fillPoly(training_mask,
|
||||
poly.astype(np.int32)[np.newaxis, :, :], 0.15)
|
||||
else:
|
||||
tcl_poly = self.poly2tcl(poly, tcl_ratio)
|
||||
tcl_quads = self.poly2quads(tcl_poly)
|
||||
poly_quads = self.poly2quads(poly)
|
||||
# stcl map
|
||||
stcl_quads, quad_index = self.shrink_poly_along_width(tcl_quads, shrink_ratio_of_width=shrink_ratio_of_width,
|
||||
expand_height_ratio=1.0 / tcl_ratio)
|
||||
stcl_quads, quad_index = self.shrink_poly_along_width(
|
||||
tcl_quads,
|
||||
shrink_ratio_of_width=shrink_ratio_of_width,
|
||||
expand_height_ratio=1.0 / tcl_ratio)
|
||||
# generate tcl map
|
||||
cv2.fillPoly(score_map, np.round(stcl_quads).astype(np.int32), 1.0)
|
||||
cv2.fillPoly(score_map,
|
||||
np.round(stcl_quads).astype(np.int32), 1.0)
|
||||
|
||||
# generate tbo map
|
||||
for idx, quad in enumerate(stcl_quads):
|
||||
quad_mask = np.zeros((h, w), dtype=np.float32)
|
||||
quad_mask = cv2.fillPoly(quad_mask, np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
|
||||
tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]], quad_mask, tbo_map)
|
||||
quad_mask = cv2.fillPoly(
|
||||
quad_mask,
|
||||
np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
|
||||
tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
|
||||
quad_mask, tbo_map)
|
||||
return score_map, tbo_map, training_mask
|
||||
|
||||
def generate_tvo_and_tco(self, hw, polys, tags, tcl_ratio=0.3, ds_ratio=0.25):
|
||||
def generate_tvo_and_tco(self,
|
||||
hw,
|
||||
polys,
|
||||
tags,
|
||||
tcl_ratio=0.3,
|
||||
ds_ratio=0.25):
|
||||
"""
|
||||
Generate tcl map, tvo map and tbo map.
|
||||
"""
|
||||
|
@ -297,35 +344,44 @@ class SASTProcessTrain(object):
|
|||
|
||||
# generate min_area_quad
|
||||
min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
|
||||
min_area_quad_h = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
|
||||
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
|
||||
min_area_quad_w = 0.5 * (np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
|
||||
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
|
||||
min_area_quad_h = 0.5 * (
|
||||
np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
|
||||
np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
|
||||
min_area_quad_w = 0.5 * (
|
||||
np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
|
||||
np.linalg.norm(min_area_quad[2] - min_area_quad[3]))
|
||||
|
||||
# generate tcl map and text, 128 * 128
|
||||
tcl_poly = self.poly2tcl(poly, tcl_ratio)
|
||||
|
||||
# generate poly_tv_xy_map
|
||||
for idx in range(4):
|
||||
cv2.fillPoly(poly_tv_xy_map[2 * idx],
|
||||
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
||||
float(min(max(min_area_quad[idx, 0], 0), w)))
|
||||
cv2.fillPoly(poly_tv_xy_map[2 * idx + 1],
|
||||
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
||||
float(min(max(min_area_quad[idx, 1], 0), h)))
|
||||
cv2.fillPoly(
|
||||
poly_tv_xy_map[2 * idx],
|
||||
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
||||
float(min(max(min_area_quad[idx, 0], 0), w)))
|
||||
cv2.fillPoly(
|
||||
poly_tv_xy_map[2 * idx + 1],
|
||||
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
||||
float(min(max(min_area_quad[idx, 1], 0), h)))
|
||||
|
||||
# generate poly_tc_xy_map
|
||||
for idx in range(2):
|
||||
cv2.fillPoly(poly_tc_xy_map[idx],
|
||||
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), float(center_point[idx]))
|
||||
cv2.fillPoly(
|
||||
poly_tc_xy_map[idx],
|
||||
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
||||
float(center_point[idx]))
|
||||
|
||||
# generate poly_short_edge_map
|
||||
cv2.fillPoly(poly_short_edge_map,
|
||||
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
||||
float(max(min(min_area_quad_h, min_area_quad_w), 1.0)))
|
||||
cv2.fillPoly(
|
||||
poly_short_edge_map,
|
||||
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
||||
float(max(min(min_area_quad_h, min_area_quad_w), 1.0)))
|
||||
|
||||
# generate poly_mask and training_mask
|
||||
cv2.fillPoly(poly_mask, np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32), 1)
|
||||
cv2.fillPoly(poly_mask,
|
||||
np.round(tcl_poly[np.newaxis, :, :]).astype(np.int32),
|
||||
1)
|
||||
|
||||
tvo_map *= poly_mask
|
||||
tvo_map[:8] -= poly_tv_xy_map
|
||||
|
@ -356,7 +412,8 @@ class SASTProcessTrain(object):
|
|||
elif point_num > 4:
|
||||
vector_1 = poly[0] - poly[1]
|
||||
vector_2 = poly[1] - poly[2]
|
||||
cos_theta = np.dot(vector_1, vector_2) / (np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
|
||||
cos_theta = np.dot(vector_1, vector_2) / (
|
||||
np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
|
||||
theta = np.arccos(np.round(cos_theta, decimals=4))
|
||||
|
||||
if abs(theta) > (70 / 180 * math.pi):
|
||||
|
@ -374,7 +431,8 @@ class SASTProcessTrain(object):
|
|||
min_area_quad = poly
|
||||
center_point = np.sum(poly, axis=0) / 4
|
||||
else:
|
||||
rect = cv2.minAreaRect(poly.astype(np.int32)) # (center (x,y), (width, height), angle of rotation)
|
||||
rect = cv2.minAreaRect(poly.astype(
|
||||
np.int32)) # (center (x,y), (width, height), angle of rotation)
|
||||
center_point = rect[0]
|
||||
box = np.array(cv2.boxPoints(rect))
|
||||
|
||||
|
@ -394,16 +452,23 @@ class SASTProcessTrain(object):
|
|||
|
||||
return min_area_quad, center_point
|
||||
|
||||
def shrink_quad_along_width(self, quad, begin_width_ratio=0., end_width_ratio=1.):
|
||||
def shrink_quad_along_width(self,
|
||||
quad,
|
||||
begin_width_ratio=0.,
|
||||
end_width_ratio=1.):
|
||||
"""
|
||||
Generate shrink_quad_along_width.
|
||||
"""
|
||||
ratio_pair = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||
ratio_pair = np.array(
|
||||
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
|
||||
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
|
||||
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
|
||||
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
|
||||
|
||||
def shrink_poly_along_width(self, quads, shrink_ratio_of_width, expand_height_ratio=1.0):
|
||||
def shrink_poly_along_width(self,
|
||||
quads,
|
||||
shrink_ratio_of_width,
|
||||
expand_height_ratio=1.0):
|
||||
"""
|
||||
shrink poly with given length.
|
||||
"""
|
||||
|
@ -421,22 +486,28 @@ class SASTProcessTrain(object):
|
|||
upper_edge_list.append(upper_edge_len)
|
||||
|
||||
# length of left edge and right edge.
|
||||
left_length = np.linalg.norm(quads[0][0] - quads[0][3]) * expand_height_ratio
|
||||
right_length = np.linalg.norm(quads[-1][1] - quads[-1][2]) * expand_height_ratio
|
||||
left_length = np.linalg.norm(quads[0][0] - quads[0][
|
||||
3]) * expand_height_ratio
|
||||
right_length = np.linalg.norm(quads[-1][1] - quads[-1][
|
||||
2]) * expand_height_ratio
|
||||
|
||||
shrink_length = min(left_length, right_length, sum(upper_edge_list)) * shrink_ratio_of_width
|
||||
shrink_length = min(left_length, right_length,
|
||||
sum(upper_edge_list)) * shrink_ratio_of_width
|
||||
# shrinking length
|
||||
upper_len_left = shrink_length
|
||||
upper_len_right = sum(upper_edge_list) - shrink_length
|
||||
|
||||
left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
|
||||
left_quad = self.shrink_quad_along_width(quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
|
||||
left_quad = self.shrink_quad_along_width(
|
||||
quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
|
||||
right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
|
||||
right_quad = self.shrink_quad_along_width(quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
|
||||
|
||||
right_quad = self.shrink_quad_along_width(
|
||||
quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
|
||||
|
||||
out_quad_list = []
|
||||
if left_idx == right_idx:
|
||||
out_quad_list.append([left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
|
||||
out_quad_list.append(
|
||||
[left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
|
||||
else:
|
||||
out_quad_list.append(left_quad)
|
||||
for idx in range(left_idx + 1, right_idx):
|
||||
|
@ -500,7 +571,8 @@ class SASTProcessTrain(object):
|
|||
"""
|
||||
Generate center line by poly clock-wise point. (4, 2)
|
||||
"""
|
||||
ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
||||
ratio_pair = np.array(
|
||||
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
||||
p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
|
||||
p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
|
||||
return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
|
||||
|
@ -509,12 +581,14 @@ class SASTProcessTrain(object):
|
|||
"""
|
||||
Generate center line by poly clock-wise point.
|
||||
"""
|
||||
ratio_pair = np.array([[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
||||
ratio_pair = np.array(
|
||||
[[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
|
||||
tcl_poly = np.zeros_like(poly)
|
||||
point_num = poly.shape[0]
|
||||
|
||||
for idx in range(point_num // 2):
|
||||
point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]) * ratio_pair
|
||||
point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
|
||||
) * ratio_pair
|
||||
tcl_poly[idx] = point_pair[0]
|
||||
tcl_poly[point_num - 1 - idx] = point_pair[1]
|
||||
return tcl_poly
|
||||
|
@ -527,8 +601,10 @@ class SASTProcessTrain(object):
|
|||
up_line = self.line_cross_two_point(quad[0], quad[1])
|
||||
lower_line = self.line_cross_two_point(quad[3], quad[2])
|
||||
|
||||
quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] - quad[2]))
|
||||
quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[2] - quad[3]))
|
||||
quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
|
||||
np.linalg.norm(quad[1] - quad[2]))
|
||||
quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
|
||||
np.linalg.norm(quad[2] - quad[3]))
|
||||
|
||||
# average angle of left and right line.
|
||||
angle = self.average_angle(quad)
|
||||
|
@ -565,7 +641,8 @@ class SASTProcessTrain(object):
|
|||
quad_num = point_num // 2 - 1
|
||||
for idx in range(quad_num):
|
||||
# reshape and adjust to clock-wise
|
||||
quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]).reshape(4, 2)[[0, 2, 3, 1]])
|
||||
quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
|
||||
).reshape(4, 2)[[0, 2, 3, 1]])
|
||||
|
||||
return np.array(quad_list)
|
||||
|
||||
|
@ -579,7 +656,8 @@ class SASTProcessTrain(object):
|
|||
return None
|
||||
|
||||
h, w, _ = im.shape
|
||||
text_polys, text_tags, hv_tags = self.check_and_validate_polys(text_polys, text_tags, (h, w))
|
||||
text_polys, text_tags, hv_tags = self.check_and_validate_polys(
|
||||
text_polys, text_tags, (h, w))
|
||||
|
||||
if text_polys.shape[0] == 0:
|
||||
return None
|
||||
|
@ -591,7 +669,7 @@ class SASTProcessTrain(object):
|
|||
if np.random.rand() < 0.5:
|
||||
asp_scale = 1.0 / asp_scale
|
||||
asp_scale = math.sqrt(asp_scale)
|
||||
|
||||
|
||||
asp_wx = asp_scale
|
||||
asp_hy = 1.0 / asp_scale
|
||||
im = cv2.resize(im, dsize=None, fx=asp_wx, fy=asp_hy)
|
||||
|
@ -610,7 +688,7 @@ class SASTProcessTrain(object):
|
|||
#no background
|
||||
im, text_polys, text_tags, hv_tags = self.crop_area(im, \
|
||||
text_polys, text_tags, hv_tags, crop_background=False)
|
||||
|
||||
|
||||
if text_polys.shape[0] == 0:
|
||||
return None
|
||||
#continue for all ignore case
|
||||
|
@ -621,17 +699,18 @@ class SASTProcessTrain(object):
|
|||
return None
|
||||
#resize image
|
||||
std_ratio = float(self.input_size) / max(new_w, new_h)
|
||||
rand_scales = np.array([0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
|
||||
rand_scales = np.array(
|
||||
[0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
|
||||
rz_scale = std_ratio * np.random.choice(rand_scales)
|
||||
im = cv2.resize(im, dsize=None, fx=rz_scale, fy=rz_scale)
|
||||
text_polys[:, :, 0] *= rz_scale
|
||||
text_polys[:, :, 1] *= rz_scale
|
||||
|
||||
|
||||
#add gaussian blur
|
||||
if np.random.rand() < 0.1 * 0.5:
|
||||
ks = np.random.permutation(5)[0] + 1
|
||||
ks = int(ks/2)*2 + 1
|
||||
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
|
||||
ks = int(ks / 2) * 2 + 1
|
||||
im = cv2.GaussianBlur(im, ksize=(ks, ks), sigmaX=0, sigmaY=0)
|
||||
#add brighter
|
||||
if np.random.rand() < 0.1 * 0.5:
|
||||
im = im * (1.0 + np.random.rand() * 0.5)
|
||||
|
@ -640,13 +719,14 @@ class SASTProcessTrain(object):
|
|||
if np.random.rand() < 0.1 * 0.5:
|
||||
im = im * (1.0 - np.random.rand() * 0.5)
|
||||
im = np.clip(im, 0.0, 255.0)
|
||||
|
||||
|
||||
# Padding the im to [input_size, input_size]
|
||||
new_h, new_w, _ = im.shape
|
||||
if min(new_w, new_h) < self.input_size * 0.5:
|
||||
return None
|
||||
|
||||
im_padded = np.ones((self.input_size, self.input_size, 3), dtype=np.float32)
|
||||
im_padded = np.ones(
|
||||
(self.input_size, self.input_size, 3), dtype=np.float32)
|
||||
im_padded[:, :, 2] = 0.485 * 255
|
||||
im_padded[:, :, 1] = 0.456 * 255
|
||||
im_padded[:, :, 0] = 0.406 * 255
|
||||
|
@ -661,24 +741,29 @@ class SASTProcessTrain(object):
|
|||
sw = int(np.random.rand() * del_w)
|
||||
|
||||
# Padding
|
||||
im_padded[sh: sh + new_h, sw: sw + new_w, :] = im.copy()
|
||||
im_padded[sh:sh + new_h, sw:sw + new_w, :] = im.copy()
|
||||
text_polys[:, :, 0] += sw
|
||||
text_polys[:, :, 1] += sh
|
||||
|
||||
score_map, border_map, training_mask = self.generate_tcl_label((self.input_size, self.input_size),
|
||||
text_polys, text_tags, 0.25)
|
||||
|
||||
score_map, border_map, training_mask = self.generate_tcl_label(
|
||||
(self.input_size, self.input_size), text_polys, text_tags, 0.25)
|
||||
|
||||
# SAST head
|
||||
tvo_map, tco_map = self.generate_tvo_and_tco((self.input_size, self.input_size), text_polys, text_tags, tcl_ratio=0.3, ds_ratio=0.25)
|
||||
tvo_map, tco_map = self.generate_tvo_and_tco(
|
||||
(self.input_size, self.input_size),
|
||||
text_polys,
|
||||
text_tags,
|
||||
tcl_ratio=0.3,
|
||||
ds_ratio=0.25)
|
||||
# print("test--------tvo_map shape:", tvo_map.shape)
|
||||
|
||||
im_padded[:, :, 2] -= 0.485 * 255
|
||||
im_padded[:, :, 1] -= 0.456 * 255
|
||||
im_padded[:, :, 0] -= 0.406 * 255
|
||||
im_padded[:, :, 2] /= (255.0 * 0.229)
|
||||
im_padded[:, :, 1] /= (255.0 * 0.224)
|
||||
im_padded[:, :, 0] /= (255.0 * 0.225)
|
||||
im_padded = im_padded.transpose((2, 0, 1))
|
||||
im_padded[:, :, 2] /= (255.0 * 0.229)
|
||||
im_padded[:, :, 1] /= (255.0 * 0.224)
|
||||
im_padded[:, :, 0] /= (255.0 * 0.225)
|
||||
im_padded = im_padded.transpose((2, 0, 1))
|
||||
|
||||
data['image'] = im_padded[::-1, :, :]
|
||||
data['score_map'] = score_map[np.newaxis, :, :]
|
||||
|
@ -686,4 +771,4 @@ class SASTProcessTrain(object):
|
|||
data['training_mask'] = training_mask[np.newaxis, :, :]
|
||||
data['tvo_map'] = tvo_map.transpose((2, 0, 1))
|
||||
data['tco_map'] = tco_map.transpose((2, 0, 1))
|
||||
return data
|
||||
return data
|
||||
|
|
|
@ -24,7 +24,9 @@ class BaseRecLabelDecode(object):
|
|||
character_type='ch',
|
||||
use_space_char=False):
|
||||
support_character_type = [
|
||||
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean'
|
||||
'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean', 'it',
|
||||
'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc', 'rsc', 'bg',
|
||||
'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr', 'ne'
|
||||
]
|
||||
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
|
||||
support_character_type, character_type)
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
a
|
||||
r
|
||||
b
|
||||
i
|
||||
c
|
||||
_
|
||||
m
|
||||
g
|
||||
/
|
||||
1
|
||||
0
|
||||
I
|
||||
L
|
||||
S
|
||||
V
|
||||
R
|
||||
C
|
||||
2
|
||||
v
|
||||
l
|
||||
6
|
||||
3
|
||||
9
|
||||
.
|
||||
j
|
||||
p
|
||||
ا
|
||||
ل
|
||||
م
|
||||
ر
|
||||
ج
|
||||
و
|
||||
ح
|
||||
ي
|
||||
ة
|
||||
5
|
||||
8
|
||||
7
|
||||
أ
|
||||
ب
|
||||
ض
|
||||
4
|
||||
ك
|
||||
س
|
||||
ه
|
||||
ث
|
||||
ن
|
||||
ط
|
||||
ع
|
||||
ت
|
||||
غ
|
||||
خ
|
||||
ف
|
||||
ئ
|
||||
ز
|
||||
إ
|
||||
د
|
||||
ص
|
||||
ظ
|
||||
ذ
|
||||
ش
|
||||
ى
|
||||
ق
|
||||
ؤ
|
||||
آ
|
||||
ء
|
||||
s
|
||||
e
|
||||
n
|
||||
w
|
||||
t
|
||||
u
|
||||
z
|
||||
d
|
||||
A
|
||||
N
|
||||
G
|
||||
h
|
||||
o
|
||||
E
|
||||
T
|
||||
H
|
||||
O
|
||||
B
|
||||
y
|
||||
F
|
||||
U
|
||||
J
|
||||
X
|
||||
W
|
||||
P
|
||||
Z
|
||||
M
|
||||
k
|
||||
q
|
||||
Y
|
||||
Q
|
||||
D
|
||||
f
|
||||
K
|
||||
x
|
||||
'
|
||||
%
|
||||
-
|
||||
#
|
||||
@
|
||||
!
|
||||
&
|
||||
$
|
||||
,
|
||||
:
|
||||
é
|
||||
?
|
||||
+
|
||||
É
|
||||
(
|
||||
|
|
@ -0,0 +1,145 @@
|
|||
b
|
||||
e
|
||||
_
|
||||
i
|
||||
m
|
||||
g
|
||||
/
|
||||
2
|
||||
0
|
||||
I
|
||||
L
|
||||
S
|
||||
V
|
||||
R
|
||||
C
|
||||
1
|
||||
v
|
||||
a
|
||||
l
|
||||
6
|
||||
9
|
||||
4
|
||||
3
|
||||
.
|
||||
j
|
||||
p
|
||||
п
|
||||
а
|
||||
з
|
||||
б
|
||||
у
|
||||
г
|
||||
н
|
||||
ц
|
||||
ь
|
||||
8
|
||||
м
|
||||
л
|
||||
і
|
||||
о
|
||||
ў
|
||||
ы
|
||||
7
|
||||
5
|
||||
М
|
||||
х
|
||||
с
|
||||
р
|
||||
ф
|
||||
я
|
||||
е
|
||||
д
|
||||
ж
|
||||
ю
|
||||
ч
|
||||
й
|
||||
к
|
||||
Д
|
||||
в
|
||||
Б
|
||||
т
|
||||
І
|
||||
ш
|
||||
ё
|
||||
э
|
||||
К
|
||||
Л
|
||||
Н
|
||||
А
|
||||
Ж
|
||||
Г
|
||||
В
|
||||
П
|
||||
З
|
||||
Е
|
||||
О
|
||||
Р
|
||||
С
|
||||
У
|
||||
Ё
|
||||
Й
|
||||
Т
|
||||
Ч
|
||||
Э
|
||||
Ц
|
||||
Ю
|
||||
Ш
|
||||
Ф
|
||||
Х
|
||||
Я
|
||||
Ь
|
||||
Ы
|
||||
Ў
|
||||
s
|
||||
c
|
||||
n
|
||||
w
|
||||
M
|
||||
o
|
||||
t
|
||||
T
|
||||
E
|
||||
A
|
||||
B
|
||||
u
|
||||
h
|
||||
y
|
||||
k
|
||||
r
|
||||
H
|
||||
d
|
||||
Y
|
||||
O
|
||||
U
|
||||
F
|
||||
f
|
||||
x
|
||||
D
|
||||
G
|
||||
N
|
||||
K
|
||||
P
|
||||
z
|
||||
J
|
||||
X
|
||||
W
|
||||
Z
|
||||
Q
|
||||
%
|
||||
-
|
||||
q
|
||||
@
|
||||
'
|
||||
!
|
||||
#
|
||||
&
|
||||
,
|
||||
:
|
||||
$
|
||||
(
|
||||
?
|
||||
é
|
||||
+
|
||||
É
|
||||
|
|
@ -0,0 +1,140 @@
|
|||
!
|
||||
#
|
||||
$
|
||||
%
|
||||
&
|
||||
'
|
||||
(
|
||||
+
|
||||
,
|
||||
-
|
||||
.
|
||||
/
|
||||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
:
|
||||
?
|
||||
@
|
||||
A
|
||||
B
|
||||
C
|
||||
D
|
||||
E
|
||||
F
|
||||
G
|
||||
H
|
||||
I
|
||||
J
|
||||
K
|
||||
L
|
||||
M
|
||||
N
|
||||
O
|
||||
P
|
||||
Q
|
||||
R
|
||||
S
|
||||
T
|
||||
U
|
||||
V
|
||||
W
|
||||
X
|
||||
Y
|
||||
Z
|
||||
_
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
h
|
||||
i
|
||||
j
|
||||
k
|
||||
l
|
||||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
t
|
||||
u
|
||||
v
|
||||
w
|
||||
x
|
||||
y
|
||||
z
|
||||
É
|
||||
é
|
||||
А
|
||||
Б
|
||||
В
|
||||
Г
|
||||
Д
|
||||
Е
|
||||
Ж
|
||||
З
|
||||
И
|
||||
Й
|
||||
К
|
||||
Л
|
||||
М
|
||||
Н
|
||||
О
|
||||
П
|
||||
Р
|
||||
С
|
||||
Т
|
||||
У
|
||||
Ф
|
||||
Х
|
||||
Ц
|
||||
Ч
|
||||
Ш
|
||||
Щ
|
||||
Ъ
|
||||
Ю
|
||||
Я
|
||||
а
|
||||
б
|
||||
в
|
||||
г
|
||||
д
|
||||
е
|
||||
ж
|
||||
з
|
||||
и
|
||||
й
|
||||
к
|
||||
л
|
||||
м
|
||||
н
|
||||
о
|
||||
п
|
||||
р
|
||||
с
|
||||
т
|
||||
у
|
||||
ф
|
||||
х
|
||||
ц
|
||||
ч
|
||||
ш
|
||||
щ
|
||||
ъ
|
||||
ь
|
||||
ю
|
||||
я
|
||||
|
|
@ -0,0 +1,136 @@
|
|||
f
|
||||
a
|
||||
_
|
||||
i
|
||||
m
|
||||
g
|
||||
/
|
||||
1
|
||||
3
|
||||
I
|
||||
L
|
||||
S
|
||||
V
|
||||
R
|
||||
C
|
||||
2
|
||||
0
|
||||
v
|
||||
l
|
||||
6
|
||||
8
|
||||
5
|
||||
.
|
||||
j
|
||||
p
|
||||
و
|
||||
د
|
||||
ر
|
||||
ك
|
||||
ن
|
||||
ش
|
||||
ه
|
||||
ا
|
||||
4
|
||||
9
|
||||
ی
|
||||
ج
|
||||
ِ
|
||||
7
|
||||
غ
|
||||
ل
|
||||
س
|
||||
ز
|
||||
ّ
|
||||
ت
|
||||
ک
|
||||
گ
|
||||
ي
|
||||
م
|
||||
ب
|
||||
ف
|
||||
چ
|
||||
خ
|
||||
ق
|
||||
ژ
|
||||
آ
|
||||
ص
|
||||
پ
|
||||
َ
|
||||
ع
|
||||
ئ
|
||||
ح
|
||||
ٔ
|
||||
ض
|
||||
ُ
|
||||
ذ
|
||||
أ
|
||||
ى
|
||||
ط
|
||||
ظ
|
||||
ث
|
||||
ة
|
||||
ً
|
||||
ء
|
||||
ؤ
|
||||
ْ
|
||||
ۀ
|
||||
إ
|
||||
ٍ
|
||||
ٌ
|
||||
ٰ
|
||||
ٓ
|
||||
ٱ
|
||||
s
|
||||
c
|
||||
e
|
||||
n
|
||||
w
|
||||
N
|
||||
E
|
||||
W
|
||||
Y
|
||||
D
|
||||
O
|
||||
H
|
||||
A
|
||||
d
|
||||
z
|
||||
r
|
||||
T
|
||||
G
|
||||
o
|
||||
t
|
||||
x
|
||||
h
|
||||
b
|
||||
B
|
||||
M
|
||||
Z
|
||||
u
|
||||
P
|
||||
F
|
||||
y
|
||||
q
|
||||
U
|
||||
K
|
||||
k
|
||||
J
|
||||
Q
|
||||
'
|
||||
X
|
||||
#
|
||||
?
|
||||
%
|
||||
$
|
||||
,
|
||||
:
|
||||
&
|
||||
!
|
||||
-
|
||||
(
|
||||
É
|
||||
@
|
||||
é
|
||||
+
|
||||
|
|
@ -1,5 +1,7 @@
|
|||
|
||||
!
|
||||
"
|
||||
#
|
||||
$
|
||||
%
|
||||
&
|
||||
|
@ -72,7 +74,7 @@ l
|
|||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
|
@ -83,45 +85,59 @@ w
|
|||
x
|
||||
y
|
||||
z
|
||||
¡
|
||||
¢
|
||||
£
|
||||
¤
|
||||
¥
|
||||
¦
|
||||
§
|
||||
¨
|
||||
|
||||
°
|
||||
´
|
||||
µ
|
||||
·
|
||||
º
|
||||
¿
|
||||
Á
|
||||
Ä
|
||||
Å
|
||||
É
|
||||
Ï
|
||||
Ô
|
||||
Ö
|
||||
Ü
|
||||
ß
|
||||
à
|
||||
á
|
||||
â
|
||||
ã
|
||||
ä
|
||||
å
|
||||
æ
|
||||
ç
|
||||
è
|
||||
é
|
||||
ê
|
||||
ë
|
||||
í
|
||||
ï
|
||||
ñ
|
||||
ò
|
||||
ó
|
||||
ô
|
||||
ö
|
||||
ø
|
||||
ù
|
||||
ú
|
||||
û
|
||||
ü
|
||||
ō
|
||||
Š
|
||||
Ÿ
|
||||
ʒ
|
||||
β
|
||||
δ
|
||||
з
|
||||
Ṡ
|
||||
‘
|
||||
€
|
||||
©
|
||||
ª
|
||||
«
|
||||
¬
|
||||
|
||||
®
|
||||
¯
|
||||
°
|
||||
±
|
||||
²
|
||||
³
|
||||
´
|
||||
µ
|
||||
¶
|
||||
·
|
||||
¸
|
||||
¹
|
||||
º
|
||||
»
|
||||
¼
|
||||
½
|
||||
¿
|
||||
Â
|
||||
Ã
|
||||
Å
|
||||
Ê
|
||||
Î
|
||||
Ð
|
||||
á
|
||||
â
|
||||
å
|
||||
æ
|
||||
é
|
||||
|
||||
|
|
|
@ -0,0 +1,162 @@
|
|||
|
||||
!
|
||||
#
|
||||
$
|
||||
%
|
||||
&
|
||||
'
|
||||
(
|
||||
+
|
||||
,
|
||||
-
|
||||
.
|
||||
/
|
||||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
:
|
||||
?
|
||||
@
|
||||
A
|
||||
B
|
||||
C
|
||||
D
|
||||
E
|
||||
F
|
||||
G
|
||||
H
|
||||
I
|
||||
J
|
||||
K
|
||||
L
|
||||
M
|
||||
N
|
||||
O
|
||||
P
|
||||
Q
|
||||
R
|
||||
S
|
||||
T
|
||||
U
|
||||
V
|
||||
W
|
||||
X
|
||||
Y
|
||||
Z
|
||||
_
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
h
|
||||
i
|
||||
j
|
||||
k
|
||||
l
|
||||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
t
|
||||
u
|
||||
v
|
||||
w
|
||||
x
|
||||
y
|
||||
z
|
||||
É
|
||||
é
|
||||
ँ
|
||||
ं
|
||||
ः
|
||||
अ
|
||||
आ
|
||||
इ
|
||||
ई
|
||||
उ
|
||||
ऊ
|
||||
ऋ
|
||||
ए
|
||||
ऐ
|
||||
ऑ
|
||||
ओ
|
||||
औ
|
||||
क
|
||||
ख
|
||||
ग
|
||||
घ
|
||||
ङ
|
||||
च
|
||||
छ
|
||||
ज
|
||||
झ
|
||||
ञ
|
||||
ट
|
||||
ठ
|
||||
ड
|
||||
ढ
|
||||
ण
|
||||
त
|
||||
थ
|
||||
द
|
||||
ध
|
||||
न
|
||||
प
|
||||
फ
|
||||
ब
|
||||
भ
|
||||
म
|
||||
य
|
||||
र
|
||||
ल
|
||||
ळ
|
||||
व
|
||||
श
|
||||
ष
|
||||
स
|
||||
ह
|
||||
़
|
||||
ा
|
||||
ि
|
||||
ी
|
||||
ु
|
||||
ू
|
||||
ृ
|
||||
ॅ
|
||||
े
|
||||
ै
|
||||
ॉ
|
||||
ो
|
||||
ौ
|
||||
्
|
||||
क़
|
||||
ख़
|
||||
ग़
|
||||
ज़
|
||||
ड़
|
||||
ढ़
|
||||
फ़
|
||||
०
|
||||
१
|
||||
२
|
||||
३
|
||||
४
|
||||
५
|
||||
६
|
||||
७
|
||||
८
|
||||
९
|
||||
॰
|
|
@ -0,0 +1,118 @@
|
|||
i
|
||||
t
|
||||
_
|
||||
m
|
||||
g
|
||||
/
|
||||
5
|
||||
I
|
||||
L
|
||||
S
|
||||
V
|
||||
R
|
||||
C
|
||||
2
|
||||
0
|
||||
1
|
||||
v
|
||||
a
|
||||
l
|
||||
7
|
||||
8
|
||||
9
|
||||
6
|
||||
.
|
||||
j
|
||||
p
|
||||
|
||||
e
|
||||
r
|
||||
o
|
||||
d
|
||||
s
|
||||
n
|
||||
3
|
||||
4
|
||||
P
|
||||
u
|
||||
c
|
||||
A
|
||||
-
|
||||
,
|
||||
"
|
||||
z
|
||||
h
|
||||
f
|
||||
b
|
||||
q
|
||||
ì
|
||||
'
|
||||
à
|
||||
O
|
||||
è
|
||||
G
|
||||
ù
|
||||
é
|
||||
ò
|
||||
;
|
||||
F
|
||||
E
|
||||
B
|
||||
N
|
||||
H
|
||||
k
|
||||
:
|
||||
U
|
||||
T
|
||||
X
|
||||
D
|
||||
K
|
||||
?
|
||||
[
|
||||
M
|
||||
|
||||
x
|
||||
y
|
||||
(
|
||||
)
|
||||
W
|
||||
ö
|
||||
º
|
||||
w
|
||||
]
|
||||
Q
|
||||
J
|
||||
+
|
||||
ü
|
||||
!
|
||||
È
|
||||
á
|
||||
%
|
||||
=
|
||||
»
|
||||
ñ
|
||||
Ö
|
||||
Y
|
||||
ä
|
||||
í
|
||||
Z
|
||||
«
|
||||
@
|
||||
ó
|
||||
ø
|
||||
ï
|
||||
ú
|
||||
ê
|
||||
ç
|
||||
Á
|
||||
É
|
||||
Å
|
||||
ß
|
||||
{
|
||||
}
|
||||
&
|
||||
`
|
||||
û
|
||||
î
|
||||
#
|
||||
$
|
|
@ -0,0 +1,153 @@
|
|||
k
|
||||
a
|
||||
_
|
||||
i
|
||||
m
|
||||
g
|
||||
/
|
||||
1
|
||||
2
|
||||
I
|
||||
L
|
||||
S
|
||||
V
|
||||
R
|
||||
C
|
||||
0
|
||||
v
|
||||
l
|
||||
6
|
||||
4
|
||||
8
|
||||
.
|
||||
j
|
||||
p
|
||||
ಗ
|
||||
ು
|
||||
ಣ
|
||||
ಪ
|
||||
ಡ
|
||||
ಿ
|
||||
ಸ
|
||||
ಲ
|
||||
ಾ
|
||||
ದ
|
||||
್
|
||||
7
|
||||
5
|
||||
3
|
||||
ವ
|
||||
ಷ
|
||||
ಬ
|
||||
ಹ
|
||||
ೆ
|
||||
9
|
||||
ಅ
|
||||
ಳ
|
||||
ನ
|
||||
ರ
|
||||
ಉ
|
||||
ಕ
|
||||
ಎ
|
||||
ೇ
|
||||
ಂ
|
||||
ೈ
|
||||
ೊ
|
||||
ೀ
|
||||
ಯ
|
||||
ೋ
|
||||
ತ
|
||||
ಶ
|
||||
ಭ
|
||||
ಧ
|
||||
ಚ
|
||||
ಜ
|
||||
ೂ
|
||||
ಮ
|
||||
ಒ
|
||||
ೃ
|
||||
ಥ
|
||||
ಇ
|
||||
ಟ
|
||||
ಖ
|
||||
ಆ
|
||||
ಞ
|
||||
ಫ
|
||||
-
|
||||
ಢ
|
||||
ಊ
|
||||
ಓ
|
||||
ಐ
|
||||
ಃ
|
||||
ಘ
|
||||
ಝ
|
||||
ೌ
|
||||
ಠ
|
||||
ಛ
|
||||
ಔ
|
||||
ಏ
|
||||
ಈ
|
||||
ಋ
|
||||
೨
|
||||
೦
|
||||
೧
|
||||
೮
|
||||
೯
|
||||
೪
|
||||
,
|
||||
೫
|
||||
೭
|
||||
೩
|
||||
೬
|
||||
ಙ
|
||||
s
|
||||
c
|
||||
e
|
||||
n
|
||||
w
|
||||
o
|
||||
u
|
||||
t
|
||||
d
|
||||
E
|
||||
A
|
||||
T
|
||||
B
|
||||
Z
|
||||
N
|
||||
G
|
||||
O
|
||||
q
|
||||
z
|
||||
r
|
||||
x
|
||||
P
|
||||
K
|
||||
M
|
||||
J
|
||||
U
|
||||
D
|
||||
f
|
||||
F
|
||||
h
|
||||
b
|
||||
W
|
||||
Y
|
||||
y
|
||||
H
|
||||
X
|
||||
Q
|
||||
'
|
||||
#
|
||||
&
|
||||
!
|
||||
@
|
||||
$
|
||||
:
|
||||
%
|
||||
é
|
||||
É
|
||||
(
|
||||
?
|
||||
+
|
||||
|
|
@ -0,0 +1,153 @@
|
|||
|
||||
!
|
||||
#
|
||||
$
|
||||
%
|
||||
&
|
||||
'
|
||||
(
|
||||
+
|
||||
,
|
||||
-
|
||||
.
|
||||
/
|
||||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
:
|
||||
?
|
||||
@
|
||||
A
|
||||
B
|
||||
C
|
||||
D
|
||||
E
|
||||
F
|
||||
G
|
||||
H
|
||||
I
|
||||
J
|
||||
K
|
||||
L
|
||||
M
|
||||
N
|
||||
O
|
||||
P
|
||||
Q
|
||||
R
|
||||
S
|
||||
T
|
||||
U
|
||||
V
|
||||
W
|
||||
X
|
||||
Y
|
||||
Z
|
||||
_
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
h
|
||||
i
|
||||
j
|
||||
k
|
||||
l
|
||||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
t
|
||||
u
|
||||
v
|
||||
w
|
||||
x
|
||||
y
|
||||
z
|
||||
É
|
||||
é
|
||||
ँ
|
||||
ं
|
||||
ः
|
||||
अ
|
||||
आ
|
||||
इ
|
||||
ई
|
||||
उ
|
||||
ऊ
|
||||
ए
|
||||
ऐ
|
||||
ऑ
|
||||
ओ
|
||||
औ
|
||||
क
|
||||
ख
|
||||
ग
|
||||
घ
|
||||
च
|
||||
छ
|
||||
ज
|
||||
झ
|
||||
ञ
|
||||
ट
|
||||
ठ
|
||||
ड
|
||||
ढ
|
||||
ण
|
||||
त
|
||||
थ
|
||||
द
|
||||
ध
|
||||
न
|
||||
प
|
||||
फ
|
||||
ब
|
||||
भ
|
||||
म
|
||||
य
|
||||
र
|
||||
ऱ
|
||||
ल
|
||||
ळ
|
||||
व
|
||||
श
|
||||
ष
|
||||
स
|
||||
ह
|
||||
़
|
||||
ा
|
||||
ि
|
||||
ी
|
||||
ु
|
||||
ू
|
||||
ृ
|
||||
ॅ
|
||||
े
|
||||
ै
|
||||
ॉ
|
||||
ो
|
||||
ौ
|
||||
्
|
||||
०
|
||||
१
|
||||
२
|
||||
३
|
||||
४
|
||||
५
|
||||
६
|
||||
७
|
||||
८
|
||||
९
|
|
@ -0,0 +1,153 @@
|
|||
|
||||
!
|
||||
#
|
||||
$
|
||||
%
|
||||
&
|
||||
'
|
||||
(
|
||||
+
|
||||
,
|
||||
-
|
||||
.
|
||||
/
|
||||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
:
|
||||
?
|
||||
@
|
||||
A
|
||||
B
|
||||
C
|
||||
D
|
||||
E
|
||||
F
|
||||
G
|
||||
H
|
||||
I
|
||||
J
|
||||
K
|
||||
L
|
||||
M
|
||||
N
|
||||
O
|
||||
P
|
||||
Q
|
||||
R
|
||||
S
|
||||
T
|
||||
U
|
||||
V
|
||||
W
|
||||
X
|
||||
Y
|
||||
Z
|
||||
_
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
h
|
||||
i
|
||||
j
|
||||
k
|
||||
l
|
||||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
t
|
||||
u
|
||||
v
|
||||
w
|
||||
x
|
||||
y
|
||||
z
|
||||
É
|
||||
é
|
||||
ः
|
||||
अ
|
||||
आ
|
||||
इ
|
||||
ई
|
||||
उ
|
||||
ऊ
|
||||
ऋ
|
||||
ए
|
||||
ऐ
|
||||
ओ
|
||||
औ
|
||||
क
|
||||
ख
|
||||
ग
|
||||
घ
|
||||
ङ
|
||||
च
|
||||
छ
|
||||
ज
|
||||
झ
|
||||
ञ
|
||||
ट
|
||||
ठ
|
||||
ड
|
||||
ढ
|
||||
ण
|
||||
त
|
||||
थ
|
||||
द
|
||||
ध
|
||||
न
|
||||
ऩ
|
||||
प
|
||||
फ
|
||||
ब
|
||||
भ
|
||||
म
|
||||
य
|
||||
र
|
||||
ऱ
|
||||
ल
|
||||
व
|
||||
श
|
||||
ष
|
||||
स
|
||||
ह
|
||||
़
|
||||
ा
|
||||
ि
|
||||
ी
|
||||
ु
|
||||
ू
|
||||
ृ
|
||||
े
|
||||
ै
|
||||
ो
|
||||
ौ
|
||||
्
|
||||
॒
|
||||
ॠ
|
||||
।
|
||||
०
|
||||
१
|
||||
२
|
||||
३
|
||||
४
|
||||
५
|
||||
६
|
||||
७
|
||||
८
|
||||
९
|
|
@ -0,0 +1,96 @@
|
|||
o
|
||||
c
|
||||
_
|
||||
i
|
||||
m
|
||||
g
|
||||
/
|
||||
2
|
||||
0
|
||||
I
|
||||
L
|
||||
S
|
||||
V
|
||||
R
|
||||
C
|
||||
1
|
||||
v
|
||||
a
|
||||
l
|
||||
4
|
||||
3
|
||||
.
|
||||
j
|
||||
p
|
||||
r
|
||||
e
|
||||
è
|
||||
t
|
||||
9
|
||||
7
|
||||
5
|
||||
8
|
||||
n
|
||||
'
|
||||
b
|
||||
s
|
||||
6
|
||||
q
|
||||
u
|
||||
á
|
||||
d
|
||||
ò
|
||||
à
|
||||
h
|
||||
z
|
||||
f
|
||||
ï
|
||||
í
|
||||
A
|
||||
ç
|
||||
x
|
||||
ó
|
||||
é
|
||||
P
|
||||
O
|
||||
Ò
|
||||
ü
|
||||
k
|
||||
À
|
||||
F
|
||||
-
|
||||
ú
|
||||
|
||||
æ
|
||||
Á
|
||||
D
|
||||
E
|
||||
w
|
||||
K
|
||||
T
|
||||
N
|
||||
y
|
||||
U
|
||||
Z
|
||||
G
|
||||
B
|
||||
J
|
||||
H
|
||||
M
|
||||
W
|
||||
Y
|
||||
X
|
||||
Q
|
||||
%
|
||||
$
|
||||
,
|
||||
@
|
||||
&
|
||||
!
|
||||
:
|
||||
(
|
||||
#
|
||||
?
|
||||
+
|
||||
É
|
||||
|
|
@ -0,0 +1,130 @@
|
|||
p
|
||||
u
|
||||
_
|
||||
i
|
||||
m
|
||||
g
|
||||
/
|
||||
8
|
||||
I
|
||||
L
|
||||
S
|
||||
V
|
||||
R
|
||||
C
|
||||
2
|
||||
0
|
||||
1
|
||||
v
|
||||
a
|
||||
l
|
||||
6
|
||||
7
|
||||
4
|
||||
5
|
||||
.
|
||||
j
|
||||
|
||||
q
|
||||
e
|
||||
s
|
||||
t
|
||||
ã
|
||||
o
|
||||
x
|
||||
9
|
||||
c
|
||||
n
|
||||
r
|
||||
z
|
||||
ç
|
||||
õ
|
||||
3
|
||||
A
|
||||
U
|
||||
d
|
||||
º
|
||||
ô
|
||||
|
||||
,
|
||||
E
|
||||
;
|
||||
ó
|
||||
á
|
||||
b
|
||||
D
|
||||
?
|
||||
ú
|
||||
ê
|
||||
-
|
||||
h
|
||||
P
|
||||
f
|
||||
à
|
||||
N
|
||||
í
|
||||
O
|
||||
M
|
||||
G
|
||||
É
|
||||
é
|
||||
â
|
||||
F
|
||||
:
|
||||
T
|
||||
Á
|
||||
"
|
||||
Q
|
||||
)
|
||||
W
|
||||
J
|
||||
B
|
||||
H
|
||||
(
|
||||
ö
|
||||
%
|
||||
Ö
|
||||
«
|
||||
w
|
||||
K
|
||||
y
|
||||
!
|
||||
k
|
||||
]
|
||||
'
|
||||
Z
|
||||
+
|
||||
Ç
|
||||
Õ
|
||||
Y
|
||||
À
|
||||
X
|
||||
µ
|
||||
»
|
||||
ª
|
||||
Í
|
||||
ü
|
||||
ä
|
||||
´
|
||||
è
|
||||
ñ
|
||||
ß
|
||||
ï
|
||||
Ú
|
||||
ë
|
||||
Ô
|
||||
Ï
|
||||
Ó
|
||||
[
|
||||
Ì
|
||||
<
|
||||
Â
|
||||
ò
|
||||
§
|
||||
³
|
||||
ø
|
||||
å
|
||||
#
|
||||
$
|
||||
&
|
||||
@
|
|
@ -0,0 +1,91 @@
|
|||
r
|
||||
s
|
||||
_
|
||||
i
|
||||
m
|
||||
g
|
||||
/
|
||||
1
|
||||
I
|
||||
L
|
||||
S
|
||||
V
|
||||
R
|
||||
C
|
||||
2
|
||||
0
|
||||
v
|
||||
a
|
||||
l
|
||||
7
|
||||
5
|
||||
8
|
||||
6
|
||||
.
|
||||
j
|
||||
p
|
||||
|
||||
t
|
||||
d
|
||||
9
|
||||
3
|
||||
e
|
||||
š
|
||||
4
|
||||
k
|
||||
u
|
||||
ć
|
||||
c
|
||||
n
|
||||
đ
|
||||
o
|
||||
z
|
||||
č
|
||||
b
|
||||
ž
|
||||
f
|
||||
Z
|
||||
T
|
||||
h
|
||||
M
|
||||
F
|
||||
O
|
||||
Š
|
||||
B
|
||||
H
|
||||
A
|
||||
E
|
||||
Đ
|
||||
Ž
|
||||
D
|
||||
P
|
||||
G
|
||||
Č
|
||||
K
|
||||
U
|
||||
N
|
||||
J
|
||||
Ć
|
||||
w
|
||||
y
|
||||
W
|
||||
x
|
||||
Y
|
||||
X
|
||||
q
|
||||
Q
|
||||
#
|
||||
&
|
||||
$
|
||||
,
|
||||
-
|
||||
%
|
||||
'
|
||||
@
|
||||
!
|
||||
:
|
||||
?
|
||||
(
|
||||
É
|
||||
é
|
||||
+
|
|
@ -0,0 +1,134 @@
|
|||
r
|
||||
s
|
||||
c
|
||||
_
|
||||
i
|
||||
m
|
||||
g
|
||||
/
|
||||
5
|
||||
I
|
||||
L
|
||||
S
|
||||
V
|
||||
R
|
||||
C
|
||||
2
|
||||
0
|
||||
1
|
||||
v
|
||||
a
|
||||
l
|
||||
9
|
||||
7
|
||||
8
|
||||
.
|
||||
j
|
||||
p
|
||||
м
|
||||
а
|
||||
с
|
||||
и
|
||||
р
|
||||
ћ
|
||||
е
|
||||
ш
|
||||
3
|
||||
4
|
||||
о
|
||||
г
|
||||
н
|
||||
з
|
||||
в
|
||||
л
|
||||
6
|
||||
т
|
||||
ж
|
||||
у
|
||||
к
|
||||
п
|
||||
њ
|
||||
д
|
||||
ч
|
||||
С
|
||||
ј
|
||||
ф
|
||||
ц
|
||||
љ
|
||||
х
|
||||
О
|
||||
И
|
||||
А
|
||||
б
|
||||
Ш
|
||||
К
|
||||
ђ
|
||||
џ
|
||||
М
|
||||
В
|
||||
З
|
||||
Д
|
||||
Р
|
||||
У
|
||||
Н
|
||||
Т
|
||||
Б
|
||||
?
|
||||
П
|
||||
Х
|
||||
Ј
|
||||
Ц
|
||||
Г
|
||||
Љ
|
||||
Л
|
||||
Ф
|
||||
e
|
||||
n
|
||||
w
|
||||
E
|
||||
F
|
||||
A
|
||||
N
|
||||
f
|
||||
o
|
||||
b
|
||||
M
|
||||
G
|
||||
t
|
||||
y
|
||||
W
|
||||
k
|
||||
P
|
||||
u
|
||||
H
|
||||
B
|
||||
T
|
||||
z
|
||||
h
|
||||
O
|
||||
Y
|
||||
d
|
||||
U
|
||||
K
|
||||
D
|
||||
x
|
||||
X
|
||||
J
|
||||
Z
|
||||
Q
|
||||
q
|
||||
'
|
||||
-
|
||||
@
|
||||
é
|
||||
#
|
||||
!
|
||||
,
|
||||
%
|
||||
$
|
||||
:
|
||||
&
|
||||
+
|
||||
(
|
||||
É
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
к
|
||||
в
|
||||
а
|
||||
з
|
||||
и
|
||||
у
|
||||
р
|
||||
о
|
||||
н
|
||||
я
|
||||
х
|
||||
п
|
||||
л
|
||||
ы
|
||||
г
|
||||
е
|
||||
т
|
||||
м
|
||||
д
|
||||
ж
|
||||
ш
|
||||
ь
|
||||
с
|
||||
ё
|
||||
б
|
||||
й
|
||||
ч
|
||||
ю
|
||||
ц
|
||||
щ
|
||||
М
|
||||
э
|
||||
ф
|
||||
А
|
||||
ъ
|
||||
С
|
||||
Ф
|
||||
Ю
|
||||
В
|
||||
К
|
||||
Т
|
||||
Н
|
||||
О
|
||||
Э
|
||||
У
|
||||
И
|
||||
Г
|
||||
Л
|
||||
Р
|
||||
Д
|
||||
Б
|
||||
Ш
|
||||
П
|
||||
З
|
||||
Х
|
||||
Е
|
||||
Ж
|
||||
Я
|
||||
Ц
|
||||
Ч
|
||||
Й
|
||||
Щ
|
||||
0
|
||||
1
|
||||
2
|
||||
3
|
||||
4
|
||||
5
|
||||
6
|
||||
7
|
||||
8
|
||||
9
|
||||
a
|
||||
b
|
||||
c
|
||||
d
|
||||
e
|
||||
f
|
||||
g
|
||||
h
|
||||
i
|
||||
j
|
||||
k
|
||||
l
|
||||
m
|
||||
n
|
||||
o
|
||||
p
|
||||
q
|
||||
r
|
||||
s
|
||||
t
|
||||
u
|
||||
v
|
||||
w
|
||||
x
|
||||
y
|
||||
z
|
||||
A
|
||||
B
|
||||
C
|
||||
D
|
||||
E
|
||||
F
|
||||
G
|
||||
H
|
||||
I
|
||||
J
|
||||
K
|
||||
L
|
||||
M
|
||||
N
|
||||
O
|
||||
P
|
||||
Q
|
||||
R
|
||||
S
|
||||
T
|
||||
U
|
||||
V
|
||||
W
|
||||
X
|
||||
Y
|
||||
Z
|
||||
|
|
@ -0,0 +1,128 @@
|
|||
t
|
||||
a
|
||||
_
|
||||
i
|
||||
m
|
||||
g
|
||||
/
|
||||
3
|
||||
I
|
||||
L
|
||||
S
|
||||
V
|
||||
R
|
||||
C
|
||||
2
|
||||
0
|
||||
1
|
||||
v
|
||||
l
|
||||
9
|
||||
7
|
||||
8
|
||||
.
|
||||
j
|
||||
p
|
||||
ப
|
||||
ூ
|
||||
த
|
||||
ம
|
||||
ி
|
||||
வ
|
||||
ர
|
||||
்
|
||||
ந
|
||||
ோ
|
||||
ன
|
||||
6
|
||||
ஆ
|
||||
ற
|
||||
ல
|
||||
5
|
||||
ள
|
||||
ா
|
||||
ொ
|
||||
ழ
|
||||
ு
|
||||
4
|
||||
ெ
|
||||
ண
|
||||
க
|
||||
ட
|
||||
ை
|
||||
ே
|
||||
ச
|
||||
ய
|
||||
ஒ
|
||||
இ
|
||||
அ
|
||||
ங
|
||||
உ
|
||||
ீ
|
||||
ஞ
|
||||
எ
|
||||
ஓ
|
||||
ஃ
|
||||
ஜ
|
||||
ஷ
|
||||
ஸ
|
||||
ஏ
|
||||
ஊ
|
||||
ஹ
|
||||
ஈ
|
||||
ஐ
|
||||
ௌ
|
||||
ஔ
|
||||
s
|
||||
c
|
||||
e
|
||||
n
|
||||
w
|
||||
F
|
||||
T
|
||||
O
|
||||
P
|
||||
K
|
||||
A
|
||||
N
|
||||
G
|
||||
Y
|
||||
E
|
||||
M
|
||||
H
|
||||
U
|
||||
B
|
||||
o
|
||||
b
|
||||
D
|
||||
d
|
||||
r
|
||||
W
|
||||
u
|
||||
y
|
||||
f
|
||||
X
|
||||
k
|
||||
q
|
||||
h
|
||||
J
|
||||
z
|
||||
Z
|
||||
Q
|
||||
x
|
||||
-
|
||||
'
|
||||
$
|
||||
,
|
||||
%
|
||||
@
|
||||
é
|
||||
!
|
||||
#
|
||||
+
|
||||
É
|
||||
&
|
||||
:
|
||||
(
|
||||
?
|
||||
|
|
@ -0,0 +1,151 @@
|
|||
t
|
||||
e
|
||||
_
|
||||
i
|
||||
m
|
||||
g
|
||||
/
|
||||
5
|
||||
I
|
||||
L
|
||||
S
|
||||
V
|
||||
R
|
||||
C
|
||||
2
|
||||
0
|
||||
1
|
||||
v
|
||||
a
|
||||
l
|
||||
3
|
||||
4
|
||||
8
|
||||
9
|
||||
.
|
||||
j
|
||||
p
|
||||
త
|
||||
ె
|
||||
ర
|
||||
క
|
||||
్
|
||||
ి
|
||||
ం
|
||||
చ
|
||||
ే
|
||||
ద
|
||||
ు
|
||||
7
|
||||
6
|
||||
ఉ
|
||||
ా
|
||||
మ
|
||||
ట
|
||||
ో
|
||||
వ
|
||||
ప
|
||||
ల
|
||||
శ
|
||||
ఆ
|
||||
య
|
||||
ై
|
||||
భ
|
||||
'
|
||||
ీ
|
||||
గ
|
||||
ూ
|
||||
డ
|
||||
ధ
|
||||
హ
|
||||
న
|
||||
జ
|
||||
స
|
||||
[
|
||||
|
||||
ష
|
||||
అ
|
||||
ణ
|
||||
ఫ
|
||||
బ
|
||||
ఎ
|
||||
;
|
||||
ళ
|
||||
థ
|
||||
ొ
|
||||
ఠ
|
||||
ృ
|
||||
ఒ
|
||||
ఇ
|
||||
ః
|
||||
ఊ
|
||||
ఖ
|
||||
-
|
||||
ఐ
|
||||
ఘ
|
||||
ౌ
|
||||
ఏ
|
||||
ఈ
|
||||
ఛ
|
||||
,
|
||||
ఓ
|
||||
ఞ
|
||||
|
|
||||
?
|
||||
:
|
||||
ఢ
|
||||
"
|
||||
(
|
||||
”
|
||||
!
|
||||
+
|
||||
)
|
||||
*
|
||||
=
|
||||
&
|
||||
“
|
||||
€
|
||||
]
|
||||
£
|
||||
$
|
||||
s
|
||||
c
|
||||
n
|
||||
w
|
||||
k
|
||||
J
|
||||
G
|
||||
u
|
||||
d
|
||||
r
|
||||
E
|
||||
o
|
||||
h
|
||||
y
|
||||
b
|
||||
f
|
||||
B
|
||||
M
|
||||
O
|
||||
T
|
||||
N
|
||||
D
|
||||
P
|
||||
A
|
||||
F
|
||||
x
|
||||
W
|
||||
Y
|
||||
U
|
||||
H
|
||||
K
|
||||
X
|
||||
z
|
||||
Z
|
||||
Q
|
||||
q
|
||||
É
|
||||
%
|
||||
#
|
||||
@
|
||||
é
|
|
@ -0,0 +1,114 @@
|
|||
u
|
||||
g
|
||||
_
|
||||
i
|
||||
m
|
||||
/
|
||||
1
|
||||
I
|
||||
L
|
||||
S
|
||||
V
|
||||
R
|
||||
C
|
||||
2
|
||||
0
|
||||
v
|
||||
a
|
||||
l
|
||||
8
|
||||
5
|
||||
3
|
||||
6
|
||||
9
|
||||
.
|
||||
j
|
||||
p
|
||||
|
||||
ق
|
||||
ا
|
||||
پ
|
||||
ل
|
||||
4
|
||||
7
|
||||
ئ
|
||||
ى
|
||||
ش
|
||||
ت
|
||||
ي
|
||||
ك
|
||||
د
|
||||
ف
|
||||
ر
|
||||
و
|
||||
ن
|
||||
ب
|
||||
ە
|
||||
خ
|
||||
ې
|
||||
چ
|
||||
ۇ
|
||||
ز
|
||||
س
|
||||
م
|
||||
ۋ
|
||||
گ
|
||||
ڭ
|
||||
ۆ
|
||||
ۈ
|
||||
ج
|
||||
غ
|
||||
ھ
|
||||
ژ
|
||||
s
|
||||
c
|
||||
e
|
||||
n
|
||||
w
|
||||
P
|
||||
E
|
||||
D
|
||||
U
|
||||
d
|
||||
r
|
||||
b
|
||||
y
|
||||
B
|
||||
o
|
||||
O
|
||||
Y
|
||||
N
|
||||
T
|
||||
k
|
||||
t
|
||||
h
|
||||
A
|
||||
H
|
||||
F
|
||||
z
|
||||
W
|
||||
K
|
||||
G
|
||||
M
|
||||
f
|
||||
Z
|
||||
X
|
||||
Q
|
||||
J
|
||||
x
|
||||
q
|
||||
-
|
||||
!
|
||||
%
|
||||
#
|
||||
?
|
||||
:
|
||||
$
|
||||
,
|
||||
&
|
||||
'
|
||||
É
|
||||
@
|
||||
é
|
||||
(
|
||||
+
|
|
@ -0,0 +1,142 @@
|
|||
u
|
||||
k
|
||||
_
|
||||
i
|
||||
m
|
||||
g
|
||||
/
|
||||
1
|
||||
6
|
||||
I
|
||||
L
|
||||
S
|
||||
V
|
||||
R
|
||||
C
|
||||
2
|
||||
0
|
||||
v
|
||||
a
|
||||
l
|
||||
7
|
||||
9
|
||||
.
|
||||
j
|
||||
p
|
||||
в
|
||||
і
|
||||
д
|
||||
п
|
||||
о
|
||||
н
|
||||
с
|
||||
т
|
||||
ю
|
||||
4
|
||||
5
|
||||
3
|
||||
а
|
||||
и
|
||||
м
|
||||
е
|
||||
р
|
||||
ч
|
||||
у
|
||||
Б
|
||||
з
|
||||
л
|
||||
к
|
||||
8
|
||||
А
|
||||
В
|
||||
г
|
||||
є
|
||||
б
|
||||
ь
|
||||
х
|
||||
ґ
|
||||
ш
|
||||
ц
|
||||
ф
|
||||
я
|
||||
щ
|
||||
ж
|
||||
Г
|
||||
Х
|
||||
У
|
||||
Т
|
||||
Е
|
||||
І
|
||||
Н
|
||||
П
|
||||
З
|
||||
Л
|
||||
Ю
|
||||
С
|
||||
Д
|
||||
М
|
||||
К
|
||||
Р
|
||||
Ф
|
||||
О
|
||||
Ц
|
||||
И
|
||||
Я
|
||||
Ч
|
||||
Ш
|
||||
Ж
|
||||
Є
|
||||
Ґ
|
||||
Ь
|
||||
s
|
||||
c
|
||||
e
|
||||
n
|
||||
w
|
||||
A
|
||||
P
|
||||
r
|
||||
E
|
||||
t
|
||||
o
|
||||
h
|
||||
d
|
||||
y
|
||||
M
|
||||
G
|
||||
N
|
||||
F
|
||||
B
|
||||
T
|
||||
D
|
||||
U
|
||||
O
|
||||
W
|
||||
Z
|
||||
f
|
||||
H
|
||||
Y
|
||||
b
|
||||
K
|
||||
z
|
||||
x
|
||||
Q
|
||||
X
|
||||
q
|
||||
J
|
||||
$
|
||||
-
|
||||
'
|
||||
#
|
||||
&
|
||||
%
|
||||
?
|
||||
:
|
||||
!
|
||||
,
|
||||
+
|
||||
@
|
||||
(
|
||||
é
|
||||
É
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
u
|
||||
r
|
||||
_
|
||||
i
|
||||
m
|
||||
g
|
||||
/
|
||||
3
|
||||
I
|
||||
L
|
||||
S
|
||||
V
|
||||
R
|
||||
C
|
||||
2
|
||||
0
|
||||
1
|
||||
v
|
||||
a
|
||||
l
|
||||
9
|
||||
7
|
||||
8
|
||||
.
|
||||
j
|
||||
p
|
||||
|
||||
چ
|
||||
ٹ
|
||||
پ
|
||||
ا
|
||||
ئ
|
||||
ی
|
||||
ے
|
||||
4
|
||||
6
|
||||
و
|
||||
ل
|
||||
ن
|
||||
ڈ
|
||||
ھ
|
||||
ک
|
||||
ت
|
||||
ش
|
||||
ف
|
||||
ق
|
||||
ر
|
||||
د
|
||||
5
|
||||
ب
|
||||
ج
|
||||
خ
|
||||
ہ
|
||||
س
|
||||
ز
|
||||
غ
|
||||
ڑ
|
||||
ں
|
||||
آ
|
||||
م
|
||||
ؤ
|
||||
ط
|
||||
ص
|
||||
ح
|
||||
ع
|
||||
گ
|
||||
ث
|
||||
ض
|
||||
ذ
|
||||
ۓ
|
||||
ِ
|
||||
ء
|
||||
ظ
|
||||
ً
|
||||
ي
|
||||
ُ
|
||||
ۃ
|
||||
أ
|
||||
ٰ
|
||||
ە
|
||||
ژ
|
||||
ۂ
|
||||
ة
|
||||
ّ
|
||||
ك
|
||||
ه
|
||||
s
|
||||
c
|
||||
e
|
||||
n
|
||||
w
|
||||
o
|
||||
d
|
||||
t
|
||||
D
|
||||
M
|
||||
T
|
||||
U
|
||||
E
|
||||
b
|
||||
P
|
||||
h
|
||||
y
|
||||
W
|
||||
H
|
||||
A
|
||||
x
|
||||
B
|
||||
O
|
||||
N
|
||||
G
|
||||
Y
|
||||
Q
|
||||
F
|
||||
k
|
||||
K
|
||||
q
|
||||
J
|
||||
Z
|
||||
f
|
||||
z
|
||||
X
|
||||
'
|
||||
@
|
||||
&
|
||||
!
|
||||
,
|
||||
:
|
||||
$
|
||||
-
|
||||
#
|
||||
?
|
||||
%
|
||||
é
|
||||
+
|
||||
(
|
||||
É
|
|
@ -0,0 +1,110 @@
|
|||
x
|
||||
i
|
||||
_
|
||||
m
|
||||
g
|
||||
/
|
||||
1
|
||||
0
|
||||
I
|
||||
L
|
||||
S
|
||||
V
|
||||
R
|
||||
C
|
||||
2
|
||||
v
|
||||
a
|
||||
l
|
||||
3
|
||||
6
|
||||
4
|
||||
5
|
||||
.
|
||||
j
|
||||
p
|
||||
|
||||
Q
|
||||
u
|
||||
e
|
||||
r
|
||||
o
|
||||
8
|
||||
7
|
||||
n
|
||||
c
|
||||
9
|
||||
t
|
||||
b
|
||||
é
|
||||
q
|
||||
d
|
||||
ó
|
||||
y
|
||||
F
|
||||
s
|
||||
,
|
||||
O
|
||||
í
|
||||
T
|
||||
f
|
||||
"
|
||||
U
|
||||
M
|
||||
h
|
||||
:
|
||||
P
|
||||
H
|
||||
A
|
||||
E
|
||||
D
|
||||
z
|
||||
N
|
||||
á
|
||||
ñ
|
||||
ú
|
||||
%
|
||||
;
|
||||
è
|
||||
+
|
||||
Y
|
||||
-
|
||||
B
|
||||
G
|
||||
(
|
||||
)
|
||||
¿
|
||||
?
|
||||
w
|
||||
¡
|
||||
!
|
||||
X
|
||||
É
|
||||
K
|
||||
k
|
||||
Á
|
||||
ü
|
||||
Ú
|
||||
«
|
||||
»
|
||||
J
|
||||
'
|
||||
ö
|
||||
W
|
||||
Z
|
||||
º
|
||||
Ö
|
||||
|
||||
[
|
||||
]
|
||||
Ç
|
||||
ç
|
||||
à
|
||||
ä
|
||||
û
|
||||
ò
|
||||
Í
|
||||
ê
|
||||
ô
|
||||
ø
|
||||
ª
|