support model link in model_dir params
This commit is contained in:
parent
037e17fc82
commit
7d47283128
29
paddleocr.py
29
paddleocr.py
|
@ -28,7 +28,7 @@ from ppocr.utils.logging import get_logger
|
|||
|
||||
logger = get_logger()
|
||||
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
|
||||
from ppocr.utils.network import maybe_download, download_with_progressbar
|
||||
from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
|
||||
from tools.infer.utility import draw_ocr, init_args, str2bool
|
||||
|
||||
__all__ = ['PaddleOCR']
|
||||
|
@ -192,20 +192,19 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
'dict_path']
|
||||
|
||||
# init model dir
|
||||
if params.det_model_dir is None:
|
||||
params.det_model_dir = os.path.join(BASE_DIR, VERSION,
|
||||
'det', det_lang)
|
||||
if params.rec_model_dir is None:
|
||||
params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
|
||||
'rec', lang)
|
||||
if params.cls_model_dir is None:
|
||||
params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
|
||||
params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'det', det_lang),
|
||||
model_urls['det'][det_lang])
|
||||
params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'rec', lang),
|
||||
model_urls['rec'][lang]['url'])
|
||||
params.cls_model_dir, cls_url = confirm_model_dir_url(params.cls_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'cls'),
|
||||
model_urls['cls'])
|
||||
# download model
|
||||
maybe_download(params.det_model_dir,
|
||||
model_urls['det'][det_lang])
|
||||
maybe_download(params.rec_model_dir,
|
||||
model_urls['rec'][lang]['url'])
|
||||
maybe_download(params.cls_model_dir, model_urls['cls'])
|
||||
maybe_download(params.det_model_dir, det_url)
|
||||
maybe_download(params.rec_model_dir, rec_url)
|
||||
maybe_download(params.cls_model_dir, cls_url)
|
||||
|
||||
if params.det_algorithm not in SUPPORT_DET_MODEL:
|
||||
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
|
||||
|
@ -277,7 +276,7 @@ def main():
|
|||
# for cmd
|
||||
args = parse_args(mMain=True)
|
||||
image_dir = args.image_dir
|
||||
if image_dir.startswith('http'):
|
||||
if is_link(image_dir):
|
||||
download_with_progressbar(image_dir, 'tmp.jpg')
|
||||
image_file_list = ['tmp.jpg']
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue