diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index 3eb5e28d..80d926a2 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -18,8 +18,10 @@ from __future__ import print_function from __future__ import unicode_literals import copy +import platform __all__ = ['build_post_process'] +from ppocr.utils.logging import get_logger from .db_postprocess import DBPostProcess, DistillationDBPostProcess from .east_postprocess import EASTPostProcess @@ -28,7 +30,10 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di TableLabelDecode, SARLabelDecode from .cls_postprocess import ClsPostProcess from .pg_postprocess import PGPostProcess -from .pse_postprocess import PSEPostProcess + +if platform.system() != "Windows": + # pse is not support in Windows + from .pse_postprocess import PSEPostProcess def build_post_process(config, global_config=None): diff --git a/tools/program.py b/tools/program.py index f484cf4a..10eb246a 100755 --- a/tools/program.py +++ b/tools/program.py @@ -395,20 +395,6 @@ def preprocess(is_train=False): config = load_config(FLAGS.config) merge_config(FLAGS.opt) - # check if set use_gpu=True in paddlepaddle cpu version - use_gpu = config['Global']['use_gpu'] - check_gpu(use_gpu) - - alg = config['Architecture']['algorithm'] - assert alg in [ - 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', - 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE' - ] - - device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' - device = paddle.set_device(device) - - config['Global']['distributed'] = dist.get_world_size() != 1 if is_train: # save_config save_model_dir = config['Global']['save_model_dir'] @@ -420,6 +406,27 @@ def preprocess(is_train=False): else: log_file = None logger = get_logger(name='root', log_file=log_file) + + # check if set use_gpu=True in paddlepaddle cpu version + use_gpu = config['Global']['use_gpu'] + check_gpu(use_gpu) + + alg = config['Architecture']['algorithm'] + assert alg in [ + 'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN', + 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE' + ] + windows_not_support_list = ['PSE'] + if platform.system() == "Windows" and alg in windows_not_support_list: + logger.warning('{} is not support in Windows now'.format( + windows_not_support_list)) + sys.exit() + + device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu' + device = paddle.set_device(device) + + config['Global']['distributed'] = dist.get_world_size() != 1 + if config['Global']['use_visualdl']: from visualdl import LogWriter save_model_dir = config['Global']['save_model_dir']