PaddleOCR/configs/rec/multi_language/generate_multi_language_con...

201 lines
7.3 KiB
Python
Raw Normal View History

2021-01-20 19:06:39 +08:00
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
from argparse import ArgumentParser, RawDescriptionHelpFormatter
2021-01-19 23:46:35 +08:00
import os.path
2021-01-20 12:08:57 +08:00
import logging
logging.basicConfig(level=logging.INFO)

# Language codes accepted by -l/--language, mapped to a human-readable
# language name. The whole mapping is shown in the --language help text.
support_list = {
    'it': 'italian',
    'es': 'spanish',
    'pt': 'portuguese',
    'ru': 'russian',
    'ar': 'arabic',
    'ta': 'tamil',
    'ug': 'uyghur',
    'fa': 'persian',
    'ur': 'urdu',
    'rs_latin': 'serbian latin',
    'oc': 'occitan',
    'rs_cyrillic': 'serbian cyrillic',
    'bg': 'bulgarian',
    'uk': 'ukranian',
    'be': 'belarusian',
    'te': 'telugu',
    'kn': 'kannada',
    'ch_tra': 'chinese tradition',
    'hi': 'hindi',
    'mr': 'marathi',
    'ne': 'nepali',
}

# The base configuration is looked up relative to the current working
# directory, so the script must be started from the directory holding it.
assert (os.path.isfile("./rec_multi_language_lite_train.yml")), (
    "Loss basic configuration file rec_multi_language_lite_train.yml."
    "You can download it from "
    "https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/configs/rec/multi_language/"
)

# Template configuration that every generated per-language config starts
# from. It is mutated in place by ArgsParser._set_language and merge_config.
# The file is opened in a `with` block so the handle is closed right after
# parsing instead of being leaked.
# NOTE(review): yaml.Loader can construct arbitrary Python objects; this is
# only acceptable because the file is a local, trusted project file.
with open("./rec_multi_language_lite_train.yml", 'rb') as _base_config_file:
    global_config = yaml.load(_base_config_file, Loader=yaml.Loader)

# Directory three levels above the current working directory; the relative
# dictionary/label/data paths are checked against it.
# NOTE(review): presumably the PaddleOCR repository root - confirm.
project_path = os.path.abspath(os.path.join(os.getcwd(), "../../../"))
2021-03-15 17:51:32 +08:00
class ArgsParser(ArgumentParser):
    """Command-line parser for the multi-language config generator.

    Besides parsing, `parse_args` post-processes two options:
    `--opt` is turned into a dict of overrides and `--language` is
    validated and applied to the module-level `global_config`.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            "-l",
            "--language",
            nargs='+',
            help="set language type, support {}".format(support_list))
        self.add_argument(
            "--train",
            type=str,
            help="you can use this command to change the train dataset default path"
        )
        self.add_argument(
            "--val",
            type=str,
            help="you can use this command to change the eval dataset default path"
        )
        self.add_argument(
            "--dict",
            type=str,
            help="you can use this command to change the dictionary default path"
        )
        self.add_argument(
            "--data_dir",
            type=str,
            help="you can use this command to change the dataset default root path"
        )

    def parse_args(self, argv=None):
        """Parse `argv` and normalise the `opt` and `language` attributes.

        Args:
            argv (list | None): Argument strings; None lets argparse read
                the process command line.

        Returns:
            argparse.Namespace: `opt` is a dict of overrides and
            `language` is the single selected language code.
        """
        args = super(ArgsParser, self).parse_args(argv)
        args.opt = self._parse_opt(args.opt)
        args.language = self._set_language(args.language)
        return args

    def _parse_opt(self, opts):
        """Convert `key=value` strings into a dict of parsed values.

        Args:
            opts (list | None): Raw `-o/--opt` items.

        Returns:
            dict: Key -> YAML-parsed value; empty when `opts` is empty.

        Raises:
            ValueError: If an item contains no '=' separator.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            # Split on the first '=' only, so values that themselves
            # contain '=' (e.g. paths or URLs) are kept intact.
            k, sep, v = s.partition('=')
            if not sep:
                raise ValueError(
                    "option '{}' is not in key=value form".format(s))
            # NOTE(review): yaml.Loader can build arbitrary Python objects;
            # values come from the user's own command line here.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config

    def _set_language(self, type):
        """Validate the chosen language and apply it to `global_config`.

        Only the first item of `type` is used. The dictionary path, model
        output directory, train/eval label lists and character type in
        `global_config` are rewritten for that language, then the default
        dictionary file is checked to exist under `project_path`.

        Args:
            type (list | None): Values given to `-l/--language`.

        Returns:
            str: The selected language code.

        Raises:
            AssertionError: If no language was given, the language is not
                in `support_list`, or the dictionary file is missing.
        """
        assert (type), "please use -l or --language to choose language type"
        assert (type[0] in support_list), (
            "the sub_keys(-l or --language) can only be one of support list: \n{},\nbut get: {}, "
            "please check your running command").format(support_list, type)
        lang = type[0]

        global_config['Global'][
            'character_dict_path'] = 'ppocr/utils/dict/{}_dict.txt'.format(lang)
        global_config['Global'][
            'save_model_dir'] = './output/rec_{}_lite'.format(lang)
        global_config['Train']['dataset'][
            'label_file_list'] = ["train_data/{}_train.txt".format(lang)]
        global_config['Eval']['dataset'][
            'label_file_list'] = ["train_data/{}_val.txt".format(lang)]
        global_config['Global']['character_type'] = lang

        dict_file = os.path.join(project_path,
                                 global_config['Global']['character_dict_path'])
        assert (os.path.isfile(dict_file)), (
            "Loss default dictionary file {}_dict.txt.You can download it from "
            "https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/ppocr/utils/dict/"
        ).format(lang)
        return lang
2021-01-19 23:46:35 +08:00
def merge_config(config):
    """
    Merge config into global config.

    A key without a dot addresses a top-level entry: a dict value is
    merged into an existing entry with ``update``; anything else replaces
    the entry. A dotted key such as ``"Global.epoch_num"`` addresses a
    nested entry; every level except the last must already exist.

    Args:
        config (dict): Config to be merged.

    Returns:
        dict: The module-level global config, modified in place.

    Raises:
        AssertionError: If the first part of a dotted key is not a
            top-level key of the global config.
        KeyError: If an intermediate part of a dotted key does not exist.
    """
    for key, value in config.items():
        if "." not in key:
            if isinstance(value, dict) and key in global_config:
                global_config[key].update(value)
            else:
                global_config[key] = value
            continue

        sub_keys = key.split('.')
        assert (
            sub_keys[0] in global_config
        ), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
            global_config.keys(), sub_keys[0])

        # Walk down to the parent of the final key, then assign the value.
        cur = global_config[sub_keys[0]]
        for sub_key in sub_keys[1:-1]:
            cur = cur[sub_key]
        cur[sub_keys[-1]] = value

    # Returned so the behaviour matches the documented contract.
    return global_config
2021-03-15 17:51:32 +08:00
2021-01-20 12:08:57 +08:00
def loss_file(path):
    """Assert that `path` exists on disk.

    Args:
        path (str): File or directory path to check.

    Raises:
        AssertionError: If nothing exists at `path`.
    """
    found = os.path.exists(path)
    assert found, "There is no such file:{},Please do not forget to put in the specified file".format(
        path)
# Script entry point: build the per-language config from the command line
# and write it next to the base config as rec_<language>_lite_train.yml.
if __name__ == '__main__':
    FLAGS = ArgsParser().parse_args()
    # -o/--opt overrides are applied first; the dedicated path flags below
    # take precedence over them.
    merge_config(FLAGS.opt)
    save_file_path = 'rec_{}_lite_train.yml'.format(FLAGS.language)

    # Apply and validate every path override before touching the output
    # file, so a bad path does not destroy a previously generated config.
    # Each path is checked relative to project_path.
    if FLAGS.train:
        global_config['Train']['dataset']['label_file_list'] = [FLAGS.train]
        train_label_path = os.path.join(project_path, FLAGS.train)
        loss_file(train_label_path)
    if FLAGS.val:
        global_config['Eval']['dataset']['label_file_list'] = [FLAGS.val]
        eval_label_path = os.path.join(project_path, FLAGS.val)
        loss_file(eval_label_path)
    if FLAGS.dict:
        global_config['Global']['character_dict_path'] = FLAGS.dict
        dict_path = os.path.join(project_path, FLAGS.dict)
        loss_file(dict_path)
    if FLAGS.data_dir:
        global_config['Eval']['dataset']['data_dir'] = FLAGS.data_dir
        global_config['Train']['dataset']['data_dir'] = FLAGS.data_dir
        data_dir = os.path.join(project_path, FLAGS.data_dir)
        loss_file(data_dir)

    # Replace any config left over from an earlier run.
    if os.path.isfile(save_file_path):
        os.remove(save_file_path)
    with open(save_file_path, 'w') as f:
        yaml.dump(
            dict(global_config), f, default_flow_style=False, sort_keys=False)

    logging.info("Project path is :{}".format(project_path))
    logging.info("Train list path set to :{}".format(global_config['Train'][
        'dataset']['label_file_list'][0]))
    logging.info("Eval list path set to :{}".format(global_config['Eval'][
        'dataset']['label_file_list'][0]))
    logging.info("Dataset root path set to :{}".format(global_config['Eval'][
        'dataset']['data_dir']))
    logging.info("Dict path set to :{}".format(global_config['Global'][
        'character_dict_path']))
    logging.info("Config file set to :configs/rec/multi_language/{}".
                 format(save_file_path))