add pse to windows_not_support_list
This commit is contained in:
parent
ac98415b48
commit
b6a9f5d2da
|
@ -18,8 +18,10 @@ from __future__ import print_function
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import copy
|
||||
import platform
|
||||
|
||||
__all__ = ['build_post_process']
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
from .db_postprocess import DBPostProcess, DistillationDBPostProcess
|
||||
from .east_postprocess import EASTPostProcess
|
||||
|
@ -28,7 +30,10 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, Di
|
|||
TableLabelDecode, SARLabelDecode
|
||||
from .cls_postprocess import ClsPostProcess
|
||||
from .pg_postprocess import PGPostProcess
|
||||
from .pse_postprocess import PSEPostProcess
|
||||
|
||||
if platform.system() != "Windows":
|
||||
# pse is not support in Windows
|
||||
from .pse_postprocess import PSEPostProcess
|
||||
|
||||
|
||||
def build_post_process(config, global_config=None):
|
||||
|
|
|
@ -395,20 +395,6 @@ def preprocess(is_train=False):
|
|||
config = load_config(FLAGS.config)
|
||||
merge_config(FLAGS.opt)
|
||||
|
||||
# check if set use_gpu=True in paddlepaddle cpu version
|
||||
use_gpu = config['Global']['use_gpu']
|
||||
check_gpu(use_gpu)
|
||||
|
||||
alg = config['Architecture']['algorithm']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE'
|
||||
]
|
||||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
device = paddle.set_device(device)
|
||||
|
||||
config['Global']['distributed'] = dist.get_world_size() != 1
|
||||
if is_train:
|
||||
# save_config
|
||||
save_model_dir = config['Global']['save_model_dir']
|
||||
|
@ -420,6 +406,27 @@ def preprocess(is_train=False):
|
|||
else:
|
||||
log_file = None
|
||||
logger = get_logger(name='root', log_file=log_file)
|
||||
|
||||
# check if set use_gpu=True in paddlepaddle cpu version
|
||||
use_gpu = config['Global']['use_gpu']
|
||||
check_gpu(use_gpu)
|
||||
|
||||
alg = config['Architecture']['algorithm']
|
||||
assert alg in [
|
||||
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
|
||||
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE'
|
||||
]
|
||||
windows_not_support_list = ['PSE']
|
||||
if platform.system() == "Windows" and alg in windows_not_support_list:
|
||||
logger.warning('{} is not support in Windows now'.format(
|
||||
windows_not_support_list))
|
||||
sys.exit()
|
||||
|
||||
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
|
||||
device = paddle.set_device(device)
|
||||
|
||||
config['Global']['distributed'] = dist.get_world_size() != 1
|
||||
|
||||
if config['Global']['use_visualdl']:
|
||||
from visualdl import LogWriter
|
||||
save_model_dir = config['Global']['save_model_dir']
|
||||
|
|
Loading…
Reference in New Issue