225 lines
7.1 KiB
Python
225 lines
7.1 KiB
Python
|
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
import yaml
|
||
|
import os
|
||
|
from argparse import ArgumentParser, RawDescriptionHelpFormatter
|
||
|
|
||
|
|
||
|
def override(dl, ks, v):
    """
    Recursively replace one entry of a nested dict or list.

    Walks ``dl`` along the key path ``ks`` and assigns the converted value
    ``v`` to the last key. List levels are addressed by integer indices
    given as strings (e.g. ``'1'``); dict levels are addressed by key.

    Args:
        dl(dict or list): dict or list to be replaced (modified in place)
        ks(list): list of keys, one per nesting level
        v(str): value to be replaced; converted with ``str2num`` first

    Raises:
        AssertionError: if ``dl`` is not a list/dict, ``ks`` is empty, a list
            index is not an integer or is out of range, or an intermediate
            dict key does not exist.
    """
    import logging

    def str2num(v):
        # Turn strings such as '300', '0.5', 'True' or '[1, 2]' into Python
        # values; anything that fails to evaluate is kept as the raw string.
        # NOTE(review): eval() runs arbitrary expressions in this module's
        # namespace. The values presumably come from command-line ``-o``
        # options, so only use this with trusted input.
        try:
            return eval(v)
        except Exception:
            return v

    assert isinstance(dl, (list, dict)), (
        "{} should be a list or a dict".format(dl))
    assert len(ks) > 0, ('length of keys should be larger than 0')
    if isinstance(dl, list):
        k = str2num(ks[0])
        # Without this check a non-numeric key fails later with an opaque
        # TypeError in the comparison / indexing below.
        assert isinstance(k, int), (
            'index({}) of a list should be an integer'.format(ks[0]))
        if len(ks) == 1:
            assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
            dl[k] = str2num(v)
        else:
            override(dl[k], ks[1:], v)
    else:
        if len(ks) == 1:
            # New leaf keys are allowed, but warn so that typos get noticed.
            if ks[0] not in dl:
                logging.getLogger(__name__).warning(
                    'A new field ({}) detected!'.format(ks[0]))
            dl[ks[0]] = str2num(v)
        else:
            # Intermediate keys must already exist: a dotted path may not
            # create a whole new nested dict.
            assert ks[0] in dl, (
                '({}) doesn\'t exist in {}, a new dict field is invalid'.
                format(ks[0], dl))
            override(dl[ks[0]], ks[1:], v)
|
||
|
|
||
|
|
||
|
def override_config(config, options=None):
    """
    Recursively override the config.

    Args:
        config(dict): dict to be replaced (modified in place)
        options(list): list of pairs(key0.key1.idx.key2=value)
            such as: [
                'topk=2',
                'VALID.transforms.1.ResizeImage.resize_short=300'
            ]

    Returns:
        config(dict): replaced config (the same object that was passed in)

    Raises:
        AssertionError: if an option is not a str or does not contain
            exactly one ``=``.
    """
    if options is None:
        return config
    for opt in options:
        assert isinstance(opt, str), (
            "option({}) should be a str".format(opt))
        assert "=" in opt, (
            "option({}) should contain a = "
            "to distinguish between key and value".format(opt))
        pair = opt.split('=')
        assert len(pair) == 2, ("there can be only one = in the option")
        key, value = pair
        # Dotted key path, e.g. 'Global.use_gpu' -> ['Global', 'use_gpu'].
        override(config, key.split('.'), value)
    return config
|
||
|
|
||
|
|
||
|
class ArgsParser(ArgumentParser):
    """Command-line parser for this tool.

    Options: a config file (required), a worker tag, repeatable ``-o``
    config overrides, and style image / text corpus / language inputs.
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-t", "--tag", default="0", help="tag for marking worker")
        # ``append`` collects every ``-o`` occurrence into a list of
        # "key0.key1=value" strings.
        self.add_argument(
            '-o',
            '--override',
            action='append',
            default=[],
            help='config options to be overridden, '
            'e.g. -o Global.use_gpu=False')
        self.add_argument(
            "--style_image",
            default="examples/style_images/1.jpg",
            help="style image to use")
        self.add_argument(
            "--text_corpus",
            default="PaddleOCR",
            help="text corpus to use")
        self.add_argument(
            "--language", default="en", help="language of the text corpus")

    def parse_args(self, argv=None):
        """Parse ``argv`` (``sys.argv[1:]`` when None) and require --config.

        Raises:
            AssertionError: if no config file was given.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        return args
|
||
|
|
||
|
|
||
|
def load_config(file_path):
    """
    Read a YAML configuration file and return its parsed contents.

    Args:
        file_path (str): path of a file whose extension is ``.yml`` or
            ``.yaml`` (other extensions are rejected).

    Returns:
        The object produced by parsing the file.
    """
    _, extension = os.path.splitext(file_path)
    assert extension in ('.yml', '.yaml'), "only support yaml files for now"
    # NOTE(review): yaml.Loader can construct arbitrary Python objects from
    # tags in the file; only load trusted config files (yaml.safe_load is
    # the safe alternative if no Python-specific tags are needed).
    with open(file_path, 'rb') as stream:
        return yaml.load(stream, Loader=yaml.Loader)
|
||
|
|
||
|
|
||
|
def gen_config(save_path="config.yml"):
    """
    Build the default SRNet config and dump it to a YAML file.

    Args:
        save_path (str): destination file. Defaults to ``config.yml`` in the
            current working directory (the previously hard-coded path).

    Returns:
        dict: the config that was written.
    """
    base_config = {
        "Global": {
            "algorithm": "SRNet",
            "use_gpu": True,
            "start_epoch": 1,
            "stage1_epoch_num": 100,
            "stage2_epoch_num": 100,
            "log_smooth_window": 20,
            "print_batch_step": 2,
            "save_model_dir": "./output/SRNet",
            "use_visualdl": False,
            "save_epoch_step": 10,
            "vgg_pretrain": "./pretrained/VGG19_pretrained",
            "vgg_load_static_pretrain": True
        },
        "Architecture": {
            "model_type": "data_aug",
            "algorithm": "SRNet",
            "net_g": {
                "name": "srnet_net_g",
                "encode_dim": 64,
                "norm": "batch",
                "use_dropout": False,
                "init_type": "xavier",
                "init_gain": 0.02,
                "use_dilation": 1
            },
            # NOTE(review): the discriminator keys appear to mirror a
            # constructor like (input_nc, ndf, netD, n_layers_D=3,
            # norm='instance', use_sigmoid=False, init_type='normal',
            # init_gain=0.02); unset keys presumably fall back to those
            # defaults. TODO confirm against the model code.
            "bg_discriminator": {
                "name": "srnet_bg_discriminator",
                "input_nc": 6,
                "ndf": 64,
                "netD": "basic",
                "norm": "none",
                "init_type": "xavier",
            },
            "fusion_discriminator": {
                "name": "srnet_fusion_discriminator",
                "input_nc": 6,
                "ndf": 64,
                "netD": "basic",
                "norm": "none",
                "init_type": "xavier",
            }
        },
        "Loss": {
            "lamb": 10,
            "perceptual_lamb": 1,
            "muvar_lamb": 50,
            "style_lamb": 500
        },
        "Optimizer": {
            "name": "Adam",
            "learning_rate": {
                "name": "lambda",
                "lr": 0.0002,
                "lr_decay_iters": 50
            },
            "beta1": 0.5,
            "beta2": 0.999,
        },
        "Train": {
            "batch_size_per_card": 8,
            "num_workers_per_card": 4,
            "dataset": {
                "delimiter": "\t",
                "data_dir": "/",
                "label_file": "tmp/label.txt",
                "transforms": [{
                    "DecodeImage": {
                        "to_rgb": True,
                        "to_np": False,
                        "channel_first": False
                    }
                }, {
                    "NormalizeImage": {
                        "scale": 1. / 255.,
                        "mean": [0.485, 0.456, 0.406],
                        "std": [0.229, 0.224, 0.225],
                        "order": None
                    }
                }, {
                    "ToCHWImage": None
                }]
            }
        }
    }
    with open(save_path, "w", encoding="utf-8") as f:
        yaml.dump(base_config, f)
    return base_config
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Run as a script: write the default SRNet config to config.yml.
    gen_config()
|