修复一些导致不可用的bug
This commit is contained in:
parent
49958dca61
commit
4eba6c0dce
|
@ -23,12 +23,8 @@ __dir__ = os.path.dirname(os.path.abspath(__file__))
|
|||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
|
||||
|
||||
import paddle
|
||||
# paddle.manual_seed(2)
|
||||
|
||||
from ppocr.utils.logging import get_logger
|
||||
from ppocr.data import build_dataloader
|
||||
from ppocr.modeling import build_model
|
||||
from ppocr.modeling.architectures import build_model
|
||||
from ppocr.postprocess import build_post_process
|
||||
from ppocr.metrics import build_metric
|
||||
from ppocr.utils.save_load import init_model
|
||||
|
@ -39,8 +35,7 @@ import tools.program as program
|
|||
def main():
|
||||
global_config = config['Global']
|
||||
# build dataloader
|
||||
eval_loader, _ = build_dataloader(config['EVAL'], device, False,
|
||||
global_config)
|
||||
valid_dataloader = build_dataloader(config, 'Eval', device, logger)
|
||||
|
||||
# build post process
|
||||
post_process_class = build_post_process(config['PostProcess'],
|
||||
|
@ -63,16 +58,13 @@ def main():
|
|||
eval_class = build_metric(config['Metric'])
|
||||
|
||||
# start eval
|
||||
metirc = program.eval(model, eval_loader, post_process_class, eval_class)
|
||||
metirc = program.eval(model, valid_dataloader, post_process_class,
|
||||
eval_class)
|
||||
logger.info('metric eval ***************')
|
||||
for k, v in metirc.items():
|
||||
logger.info('{}:{}'.format(k, v))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
device, config = program.preprocess()
|
||||
paddle.disable_static(device)
|
||||
|
||||
logger = get_logger()
|
||||
print_dict(config, logger)
|
||||
config, device, logger, vdl_writer = program.preprocess()
|
||||
main()
|
||||
|
|
Loading…
Reference in New Issue