修复一些导致不可用的bug

This commit is contained in:
WenmuZhou 2020-11-09 13:28:15 +08:00
parent 49958dca61
commit 4eba6c0dce
1 changed files with 5 additions and 13 deletions

View File

@ -23,12 +23,8 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
import paddle
# paddle.manual_seed(2)
from ppocr.utils.logging import get_logger
from ppocr.data import build_dataloader from ppocr.data import build_dataloader
from ppocr.modeling import build_model from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import init_model
@ -39,8 +35,7 @@ import tools.program as program
def main(): def main():
global_config = config['Global'] global_config = config['Global']
# build dataloader # build dataloader
eval_loader, _ = build_dataloader(config['EVAL'], device, False, valid_dataloader = build_dataloader(config, 'Eval', device, logger)
global_config)
# build post process # build post process
post_process_class = build_post_process(config['PostProcess'], post_process_class = build_post_process(config['PostProcess'],
@ -63,16 +58,13 @@ def main():
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# start eval # start eval
metirc = program.eval(model, eval_loader, post_process_class, eval_class) metirc = program.eval(model, valid_dataloader, post_process_class,
eval_class)
logger.info('metric eval ***************') logger.info('metric eval ***************')
for k, v in metirc.items(): for k, v in metirc.items():
logger.info('{}:{}'.format(k, v)) logger.info('{}:{}'.format(k, v))
if __name__ == '__main__': if __name__ == '__main__':
device, config = program.preprocess() config, device, logger, vdl_writer = program.preprocess()
paddle.disable_static(device)
logger = get_logger()
print_dict(config, logger)
main() main()