add pse to windows_not_support_list

This commit is contained in:
WenmuZhou 2021-09-27 19:43:36 +08:00
parent ac98415b48
commit b6a9f5d2da
2 changed files with 27 additions and 15 deletions

View File

@ -18,8 +18,10 @@ from __future__ import print_function
from __future__ import unicode_literals
import copy
import platform
__all__ = ['build_post_process']
from ppocr.utils.logging import get_logger
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
from .east_postprocess import EASTPostProcess
@ -28,7 +30,10 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
TableLabelDecode, SARLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .pse_postprocess import PSEPostProcess
if platform.system() != "Windows":
# pse is not support in Windows
from .pse_postprocess import PSEPostProcess
def build_post_process(config, global_config=None):

View File

@ -395,20 +395,6 @@ def preprocess(is_train=False):
config = load_config(FLAGS.config)
merge_config(FLAGS.opt)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu']
check_gpu(use_gpu)
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE'
]
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device)
config['Global']['distributed'] = dist.get_world_size() != 1
if is_train:
# save_config
save_model_dir = config['Global']['save_model_dir']
@ -420,6 +406,27 @@ def preprocess(is_train=False):
else:
log_file = None
logger = get_logger(name='root', log_file=log_file)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu']
check_gpu(use_gpu)
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE'
]
windows_not_support_list = ['PSE']
if platform.system() == "Windows" and alg in windows_not_support_list:
logger.warning('{} is not support in Windows now'.format(
windows_not_support_list))
sys.exit()
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device)
config['Global']['distributed'] = dist.get_world_size() != 1
if config['Global']['use_visualdl']:
from visualdl import LogWriter
save_model_dir = config['Global']['save_model_dir']