support model link in model_dir params

This commit is contained in:
WenmuZhou 2021-06-10 14:47:23 +08:00
parent 037e17fc82
commit 7d47283128
1 changed files with 14 additions and 15 deletions

View File

@ -28,7 +28,7 @@ from ppocr.utils.logging import get_logger
logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar
from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
from tools.infer.utility import draw_ocr, init_args, str2bool
__all__ = ['PaddleOCR']
@ -192,20 +192,19 @@ class PaddleOCR(predict_system.TextSystem):
'dict_path']
# init model dir
if params.det_model_dir is None:
params.det_model_dir = os.path.join(BASE_DIR, VERSION,
'det', det_lang)
if params.rec_model_dir is None:
params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
'rec', lang)
if params.cls_model_dir is None:
params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
# download model
maybe_download(params.det_model_dir,
params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
os.path.join(BASE_DIR, VERSION, 'det', det_lang),
model_urls['det'][det_lang])
maybe_download(params.rec_model_dir,
params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
os.path.join(BASE_DIR, VERSION, 'rec', lang),
model_urls['rec'][lang]['url'])
maybe_download(params.cls_model_dir, model_urls['cls'])
params.cls_model_dir, cls_url = confirm_model_dir_url(params.cls_model_dir,
os.path.join(BASE_DIR, VERSION, 'cls'),
model_urls['cls'])
# download model
maybe_download(params.det_model_dir, det_url)
maybe_download(params.rec_model_dir, rec_url)
maybe_download(params.cls_model_dir, cls_url)
if params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
@ -277,7 +276,7 @@ def main():
# for cmd
args = parse_args(mMain=True)
image_dir = args.image_dir
if image_dir.startswith('http'):
if is_link(image_dir):
download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg']
else: