support model link in model_dir params

This commit is contained in:
WenmuZhou 2021-06-10 14:47:23 +08:00
parent 037e17fc82
commit 7d47283128
1 changed files with 14 additions and 15 deletions

View File

@ -28,7 +28,7 @@ from ppocr.utils.logging import get_logger
logger = get_logger() logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
from tools.infer.utility import draw_ocr, init_args, str2bool from tools.infer.utility import draw_ocr, init_args, str2bool
__all__ = ['PaddleOCR'] __all__ = ['PaddleOCR']
@ -192,20 +192,19 @@ class PaddleOCR(predict_system.TextSystem):
'dict_path'] 'dict_path']
# init model dir # init model dir
if params.det_model_dir is None: params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
params.det_model_dir = os.path.join(BASE_DIR, VERSION, os.path.join(BASE_DIR, VERSION, 'det', det_lang),
'det', det_lang) model_urls['det'][det_lang])
if params.rec_model_dir is None: params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
params.rec_model_dir = os.path.join(BASE_DIR, VERSION, os.path.join(BASE_DIR, VERSION, 'rec', lang),
'rec', lang) model_urls['rec'][lang]['url'])
if params.cls_model_dir is None: params.cls_model_dir, cls_url = confirm_model_dir_url(params.cls_model_dir,
params.cls_model_dir = os.path.join(BASE_DIR, 'cls') os.path.join(BASE_DIR, VERSION, 'cls'),
model_urls['cls'])
# download model # download model
maybe_download(params.det_model_dir, maybe_download(params.det_model_dir, det_url)
model_urls['det'][det_lang]) maybe_download(params.rec_model_dir, rec_url)
maybe_download(params.rec_model_dir, maybe_download(params.cls_model_dir, cls_url)
model_urls['rec'][lang]['url'])
maybe_download(params.cls_model_dir, model_urls['cls'])
if params.det_algorithm not in SUPPORT_DET_MODEL: if params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
@ -277,7 +276,7 @@ def main():
# for cmd # for cmd
args = parse_args(mMain=True) args = parse_args(mMain=True)
image_dir = args.image_dir image_dir = args.image_dir
if image_dir.startswith('http'): if is_link(image_dir):
download_with_progressbar(image_dir, 'tmp.jpg') download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg'] image_file_list = ['tmp.jpg']
else: else: