fix bugs
parent 46ac85ad8a
commit d3c50fda3c

@@ -1,6 +1,8 @@
 import yaml
 from argparse import ArgumentParser, RawDescriptionHelpFormatter
 import os.path
+import logging
+logging.basicConfig(level=logging.INFO)
 
 support_list = {
     'it':'italian', 'xi':'spanish', 'pu':'portuguese', 'ru':'russian', 'ar':'arabic',
@@ -16,6 +18,7 @@ You can download it from \
 https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/configs/rec/multi_language/"
 
 global_config = yaml.load(open("./rec_multi_language_lite_train.yml", 'rb'), Loader=yaml.Loader)
+project_path = os.path.abspath(os.path.join(os.getcwd(), "../../../"))
 
 class ArgsParser(ArgumentParser):
     def __init__(self):
@@ -32,7 +35,7 @@ class ArgsParser(ArgumentParser):
         self.add_argument(
             "--dict",type=str,help="you can use this command to change the dictionary default path")
         self.add_argument(
-            "--dataset_root_path",type=str,help="you can use this command to change the dataset default root path")
+            "--data_dir",type=str,help="you can use this command to change the dataset default root path")
 
     def parse_args(self, argv=None):
         args = super(ArgsParser, self).parse_args(argv)
@@ -51,15 +54,19 @@ class ArgsParser(ArgumentParser):
         return config
 
     def _set_language(self, type):
-        assert(type),"please use -t or --type to choose language type"
+        assert(type),"please use -l or --language to choose language type"
         assert(
             type[0] in support_list.keys()
-        ),"the sub_keys(-t or --type) can only be one of support list: \n{},\nbut get: {}, " \
+        ),"the sub_keys(-l or --language) can only be one of support list: \n{},\nbut get: {}, " \
             "please check your running command".format(support_list, type)
         global_config['Global']['character_dict_path'] = 'ppocr/utils/dict/{}_dict.txt'.format(type[0])
         global_config['Global']['save_model_dir'] = './output/rec_{}_lite'.format(type[0])
         global_config['Train']['dataset']['label_file_list'] = ["train_data/{}_train.txt".format(type[0])]
         global_config['Eval']['dataset']['label_file_list'] = ["train_data/{}_val.txt".format(type[0])]
+        assert(
+            os.path.isfile(os.path.join(project_path,global_config['Global']['character_dict_path']))
+        ),"Loss default dictionary file {}_dict.txt.You can download it from \
+https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/ppocr/utils/dict/".format(type[0])
         return type[0]
 
 
@@ -88,27 +95,42 @@ def merge_config(config):
                     cur[sub_key] = value
                 else:
                     cur = cur[sub_key]
 
+def loss_file(path):
+    if not os.path.exists(path):
+        logging.warning('There is no such file:{},Please do not forget to put in the specified file'.format(path))
+
+
 if __name__ == '__main__':
     FLAGS = ArgsParser().parse_args()
     merge_config(FLAGS.opt)
     if FLAGS.train:
         global_config['Train']['dataset']['label_file_list'] = [FLAGS.train]
+        train_label_path = os.path.join(project_path,FLAGS.train)
+        loss_file(train_label_path)
     if FLAGS.val:
         global_config['Eval']['dataset']['label_file_list'] = [FLAGS.val]
+        eval_label_path = os.path.join(project_path,FLAGS.val)
+        loss_file(eval_label_path)
     if FLAGS.dict:
         global_config['Global']['character_dict_path'] = FLAGS.dict
-    if FLAGS.dataset_root_path:
-        global_config['Eval']['dataset']['data_dir'] = FLAGS.dataset_root_path
-        global_config['Train']['dataset']['data_dir'] = FLAGS.dataset_root_path
+        dict_path = os.path.join(project_path,FLAGS.dict)
+        loss_file(dict_path)
+    if FLAGS.data_dir:
+        global_config['Eval']['dataset']['data_dir'] = FLAGS.data_dir
+        global_config['Train']['dataset']['data_dir'] = FLAGS.data_dir
+        data_dir = os.path.join(project_path,FLAGS.data_dir)
+        loss_file(data_dir)
+
 
     save_file_path = 'rec_{}_lite_train.yml'.format(FLAGS.language)
     if os.path.isfile(save_file_path):
         os.remove(save_file_path)
     with open(save_file_path, 'w') as f:
         yaml.dump(dict(global_config), f, default_flow_style=False, sort_keys=False)
-    print("Train list path set to :{}".format(global_config['Train']['dataset']['label_file_list'][0]))
-    print("Eval list path set to :{}".format(global_config['Eval']['dataset']['label_file_list'][0]))
-    print("Dataset root path set to :{}".format(global_config['Eval']['dataset']['data_dir']))
-    print("Dict path set to :{}".format(global_config['Global']['character_dict_path']))
-    print("Config file set to :configs/rec/multi_language/{}".format(save_file_path))
+    logging.info("Project path is :{}".format(project_path))
+    logging.info("Train list path set to :{}".format(global_config['Train']['dataset']['label_file_list'][0]))
+    logging.info("Eval list path set to :{}".format(global_config['Eval']['dataset']['label_file_list'][0]))
+    logging.info("Dataset root path set to :{}".format(global_config['Eval']['dataset']['data_dir']))
+    logging.info("Dict path set to :{}".format(global_config['Global']['character_dict_path']))
+    logging.info("Config file set to :configs/rec/multi_language/{}".format(save_file_path))
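Everything above is the change to the config-generation script itself. The two hunks that follow touch the YAML template the script loads (presumably rec_multi_language_lite_train.yml), dropping the ./ prefix from the default data_dir so it matches the train_data/ paths the script now writes into label_file_list.
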
@@ -64,7 +64,7 @@ Metric:
 Train:
   dataset:
     name: SimpleDataSet
-    data_dir: ./train_data/
+    data_dir: train_data/
     label_file_list: ["./train_data/train_list.txt"]
     transforms:
       - DecodeImage: # load image
@@ -85,7 +85,7 @@ Train:
 Eval:
   dataset:
     name: SimpleDataSet
-    data_dir: ./train_data/
+    data_dir: train_data/
     label_file_list: ["./train_data/val_list.txt"]
     transforms:
       - DecodeImage: # load image
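
As a rough illustration of what the generator now produces (not shown in the commit): assuming it is run with -l it and no extra path arguments, the overridden fields of the saved config, named along the lines of rec_it_lite_train.yml, would look roughly like the sketch below. The values follow the format strings in the diff; all other keys are carried over from the template unchanged.

Global:
  character_dict_path: ppocr/utils/dict/it_dict.txt
  save_model_dir: ./output/rec_it_lite
Train:
  dataset:
    data_dir: train_data/
    label_file_list:
    - train_data/it_train.txt
Eval:
  dataset:
    data_dir: train_data/
    label_file_list:
    - train_data/it_val.txt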