# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import yaml
from argparse import ArgumentParser, RawDescriptionHelpFormatter
import os.path
import logging

logging.basicConfig(level=logging.INFO)
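
# Illustrative usage (a sketch only; the script file name below is an assumption,
# substitute the actual name of this file and run it from configs/rec/multi_language/;
# the label-file paths are hypothetical placeholders):
#   python3 generate_multi_language_configs.py -l it \
#       --train train_data/it_train.txt --val train_data/it_val.txt \
#       --data_dir train_data/ \
#       -o Global.use_gpu=False
# The flags themselves (-o, -l/--language, --train, --val, --dict, --data_dir)
# are exactly the ones defined by ArgsParser below.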

support_list = {
    'it': 'italian',
    'xi': 'spanish',
    'pu': 'portuguese',
    'ru': 'russian',
    'ar': 'arabic',
    'ta': 'tamil',
    'ug': 'uyghur',
    'fa': 'persian',
    'ur': 'urdu',
    'rs': 'serbian latin',
    'oc': 'occitan',
    'rsc': 'serbian cyrillic',
    'bg': 'bulgarian',
    'uk': 'ukrainian',
    'be': 'belarusian',
    'te': 'telugu',
    'ka': 'kannada',
    'chinese_cht': 'chinese traditional',
    'hi': 'hindi',
    'mr': 'marathi',
    'ne': 'nepali',
}
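
# Language codes grouped by script family. _set_language collapses any code found
# in one of these lists to the group name, so e.g. '-l it' produces the shared
# 'latin' configuration (dictionary ppocr/utils/dict/latin_dict.txt, output
# directory ./output/rec_latin_lite); codes not listed here (e.g. 'te', 'ka')
# keep their own per-language dictionary.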
latin_lang = [
    'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',
    'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
    'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
    'sw', 'tl', 'tr', 'uz', 'vi', 'latin'
]
arabic_lang = ['ar', 'fa', 'ug', 'ur']
cyrillic_lang = [
    'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',
    'dar', 'inh', 'che', 'lbe', 'lez', 'tab', 'cyrillic'
]
devanagari_lang = [
    'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',
    'sa', 'bgc', 'devanagari'
]
multi_lang = latin_lang + arabic_lang + cyrillic_lang + devanagari_lang

assert os.path.isfile("./rec_multi_language_lite_train.yml"), \
    "Missing basic configuration file rec_multi_language_lite_train.yml. " \
    "You can download it from " \
    "https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/configs/rec/multi_language/"

with open("./rec_multi_language_lite_train.yml", 'rb') as f:
    global_config = yaml.load(f, Loader=yaml.Loader)
# This script is expected to be run from configs/rec/multi_language/, so the
# repository root sits three directories up.
project_path = os.path.abspath(os.path.join(os.getcwd(), "../../../"))

class ArgsParser(ArgumentParser):
    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            "-l",
            "--language",
            nargs='+',
            help="set language type, support {}".format(support_list))
        self.add_argument(
            "--train",
            type=str,
            help="override the default path of the training label file")
        self.add_argument(
            "--val",
            type=str,
            help="override the default path of the evaluation label file")
        self.add_argument(
            "--dict",
            type=str,
            help="override the default path of the character dictionary")
        self.add_argument(
            "--data_dir",
            type=str,
            help="override the default root path of the dataset")

    def parse_args(self, argv=None):
        args = super(ArgsParser, self).parse_args(argv)
        args.opt = self._parse_opt(args.opt)
        args.language = self._set_language(args.language)
        return args
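
    # A sketch of what _parse_opt consumes (the option values are hypothetical):
    #   -o Global.use_gpu=False Global.epoch_num=500
    # arrives here as ["Global.use_gpu=False", "Global.epoch_num=500"] and is
    # parsed into {"Global.use_gpu": False, "Global.epoch_num": 500}, which
    # merge_config later writes into the loaded YAML.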
    def _parse_opt(self, opts):
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            # Split on the first '=' only, so values may themselves contain '='.
            k, v = s.split('=', 1)
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config

    def _set_language(self, type):
        assert (type), "please use -l or --language to choose language type"
        lang = type[0]
        assert (
            lang in support_list.keys() or lang in multi_lang
        ), "the language (-l or --language) must be one of the supported list:\n{},\nbut got: {}. " \
            "Please check your running command".format(multi_lang, type)
        if lang in latin_lang:
            lang = "latin"
        elif lang in arabic_lang:
            lang = "arabic"
        elif lang in cyrillic_lang:
            lang = "cyrillic"
        elif lang in devanagari_lang:
            lang = "devanagari"
        global_config['Global'][
            'character_dict_path'] = 'ppocr/utils/dict/{}_dict.txt'.format(lang)
        global_config['Global'][
            'save_model_dir'] = './output/rec_{}_lite'.format(lang)
        global_config['Train']['dataset'][
            'label_file_list'] = ["train_data/{}_train.txt".format(lang)]
        global_config['Eval']['dataset'][
            'label_file_list'] = ["train_data/{}_val.txt".format(lang)]
        global_config['Global']['character_type'] = lang
        assert os.path.isfile(
            os.path.join(project_path,
                         global_config['Global']['character_dict_path'])
        ), "Missing default dictionary file {}_dict.txt. You can download it from " \
            "https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/ppocr/utils/dict/".format(
                lang)
        return lang


def merge_config(config):
    """
    Merge config into global config.
    Args:
        config (dict): Config to be merged.
    Returns: None. The module-level global_config is updated in place.
    """
    for key, value in config.items():
        if "." not in key:
            if isinstance(value, dict) and key in global_config:
                global_config[key].update(value)
            else:
                global_config[key] = value
        else:
            sub_keys = key.split('.')
            assert (
                sub_keys[0] in global_config
            ), "the sub_keys can only be one of global_config: {}, but got: {}, please check your running command".format(
                global_config.keys(), sub_keys[0])
            cur = global_config[sub_keys[0]]
            for idx, sub_key in enumerate(sub_keys[1:]):
                if idx == len(sub_keys) - 2:
                    cur[sub_key] = value
                else:
                    cur = cur[sub_key]
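
# Illustrative merge_config behaviour (the key and value here are hypothetical):
#   merge_config({"Global.epoch_num": 500})
# walks the dotted key and performs
#   global_config["Global"]["epoch_num"] = 500
# while an un-dotted dict value, e.g. {"Global": {"epoch_num": 500}}, is merged
# via global_config["Global"].update(...).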


def loss_file(path):
    # Sanity check that the referenced file or directory actually exists.
    assert os.path.exists(path), \
        "There is no such file: {}. Please do not forget to provide the specified file".format(
            path)


if __name__ == '__main__':
    FLAGS = ArgsParser().parse_args()
    merge_config(FLAGS.opt)
    save_file_path = 'rec_{}_lite_train.yml'.format(FLAGS.language)
    # Remove any config file generated by a previous run before writing a fresh one.
    if os.path.isfile(save_file_path):
        os.remove(save_file_path)

    # Command-line overrides replace the defaults set by _set_language; every
    # overridden path is verified relative to the repository root.
    if FLAGS.train:
        global_config['Train']['dataset']['label_file_list'] = [FLAGS.train]
        train_label_path = os.path.join(project_path, FLAGS.train)
        loss_file(train_label_path)
    if FLAGS.val:
        global_config['Eval']['dataset']['label_file_list'] = [FLAGS.val]
        eval_label_path = os.path.join(project_path, FLAGS.val)
        loss_file(eval_label_path)
    if FLAGS.dict:
        global_config['Global']['character_dict_path'] = FLAGS.dict
        dict_path = os.path.join(project_path, FLAGS.dict)
        loss_file(dict_path)
    if FLAGS.data_dir:
        global_config['Eval']['dataset']['data_dir'] = FLAGS.data_dir
        global_config['Train']['dataset']['data_dir'] = FLAGS.data_dir
        data_dir = os.path.join(project_path, FLAGS.data_dir)
        loss_file(data_dir)

    with open(save_file_path, 'w') as f:
        yaml.dump(
            dict(global_config), f, default_flow_style=False, sort_keys=False)
    logging.info("Project path is: {}".format(project_path))
    logging.info("Train list path set to: {}".format(global_config['Train'][
        'dataset']['label_file_list'][0]))
    logging.info("Eval list path set to: {}".format(global_config['Eval'][
        'dataset']['label_file_list'][0]))
    logging.info("Dataset root path set to: {}".format(global_config['Eval'][
        'dataset']['data_dir']))
    logging.info("Dict path set to: {}".format(global_config['Global'][
        'character_dict_path']))
    logging.info("Config file set to: configs/rec/multi_language/{}".format(
        save_file_path))