mv download func to ppocr/utils/network.py
This commit is contained in:
parent
20466055b2
commit
a5f7511505
126
paddleocr.py
126
paddleocr.py
|
@ -21,15 +21,13 @@ sys.path.append(os.path.join(__dir__, ''))
|
|||
import cv2
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import tarfile
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from tools.infer import predict_system
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
|
||||
from ppocr.utils.network import maybe_download, download_with_progressbar
|
||||
from tools.infer.utility import draw_ocr, init_args, str2bool
|
||||
|
||||
__all__ = ['PaddleOCR']
|
||||
|
@ -37,84 +35,84 @@ __all__ = ['PaddleOCR']
|
|||
model_urls = {
|
||||
'det': {
|
||||
'ch':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
|
||||
'en':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar'
|
||||
},
|
||||
'rec': {
|
||||
'ch': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
|
||||
},
|
||||
'en': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/en_dict.txt'
|
||||
},
|
||||
'french': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/french_dict.txt'
|
||||
},
|
||||
'german': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/german_dict.txt'
|
||||
},
|
||||
'korean': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/korean_dict.txt'
|
||||
},
|
||||
'japan': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/japan_dict.txt'
|
||||
},
|
||||
'chinese_cht': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
|
||||
},
|
||||
'ta': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/ta_dict.txt'
|
||||
},
|
||||
'te': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/te_dict.txt'
|
||||
},
|
||||
'ka': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/ka_dict.txt'
|
||||
},
|
||||
'latin': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/latin_dict.txt'
|
||||
},
|
||||
'arabic': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/arabic_dict.txt'
|
||||
},
|
||||
'cyrillic': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
|
||||
},
|
||||
'devanagari': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
|
||||
'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
|
||||
}
|
||||
},
|
||||
'cls':
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
|
||||
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar'
|
||||
}
|
||||
|
||||
SUPPORT_DET_MODEL = ['DB']
|
||||
|
@ -123,50 +121,6 @@ SUPPORT_REC_MODEL = ['CRNN']
|
|||
BASE_DIR = os.path.expanduser("~/.paddleocr/")
|
||||
|
||||
|
||||
def download_with_progressbar(url, save_path):
|
||||
response = requests.get(url, stream=True)
|
||||
total_size_in_bytes = int(response.headers.get('content-length', 0))
|
||||
block_size = 1024 # 1 Kibibyte
|
||||
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
|
||||
with open(save_path, 'wb') as file:
|
||||
for data in response.iter_content(block_size):
|
||||
progress_bar.update(len(data))
|
||||
file.write(data)
|
||||
progress_bar.close()
|
||||
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
|
||||
logger.error("Something went wrong while downloading models")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def maybe_download(model_storage_directory, url):
|
||||
# using custom model
|
||||
tar_file_name_list = [
|
||||
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
|
||||
]
|
||||
if not os.path.exists(
|
||||
os.path.join(model_storage_directory, 'inference.pdiparams')
|
||||
) or not os.path.exists(
|
||||
os.path.join(model_storage_directory, 'inference.pdmodel')):
|
||||
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
|
||||
print('download {} to {}'.format(url, tmp_path))
|
||||
os.makedirs(model_storage_directory, exist_ok=True)
|
||||
download_with_progressbar(url, tmp_path)
|
||||
with tarfile.open(tmp_path, 'r') as tarObj:
|
||||
for member in tarObj.getmembers():
|
||||
filename = None
|
||||
for tar_file_name in tar_file_name_list:
|
||||
if tar_file_name in member.name:
|
||||
filename = tar_file_name
|
||||
if filename is None:
|
||||
continue
|
||||
file = tarObj.extractfile(member)
|
||||
with open(
|
||||
os.path.join(model_storage_directory, filename),
|
||||
'wb') as f:
|
||||
f.write(file.read())
|
||||
os.remove(tmp_path)
|
||||
|
||||
|
||||
def parse_args(mMain=True):
|
||||
import argparse
|
||||
parser = init_args()
|
||||
|
@ -194,10 +148,10 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
args:
|
||||
**kwargs: other params show in paddleocr --help
|
||||
"""
|
||||
postprocess_params = parse_args(mMain=False)
|
||||
postprocess_params.__dict__.update(**kwargs)
|
||||
self.use_angle_cls = postprocess_params.use_angle_cls
|
||||
lang = postprocess_params.lang
|
||||
params = parse_args(mMain=False)
|
||||
params.__dict__.update(**kwargs)
|
||||
self.use_angle_cls = params.use_angle_cls
|
||||
lang = params.lang
|
||||
latin_lang = [
|
||||
'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
|
||||
'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
|
||||
|
@ -223,46 +177,46 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
lang = "devanagari"
|
||||
assert lang in model_urls[
|
||||
'rec'], 'param lang must in {}, but got {}'.format(
|
||||
model_urls['rec'].keys(), lang)
|
||||
model_urls['rec'].keys(), lang)
|
||||
if lang == "ch":
|
||||
det_lang = "ch"
|
||||
else:
|
||||
det_lang = "en"
|
||||
use_inner_dict = False
|
||||
if postprocess_params.rec_char_dict_path is None:
|
||||
if params.rec_char_dict_path is None:
|
||||
use_inner_dict = True
|
||||
postprocess_params.rec_char_dict_path = model_urls['rec'][lang][
|
||||
params.rec_char_dict_path = model_urls['rec'][lang][
|
||||
'dict_path']
|
||||
|
||||
# init model dir
|
||||
if postprocess_params.det_model_dir is None:
|
||||
postprocess_params.det_model_dir = os.path.join(BASE_DIR, VERSION,
|
||||
if params.det_model_dir is None:
|
||||
params.det_model_dir = os.path.join(BASE_DIR, VERSION,
|
||||
'det', det_lang)
|
||||
if postprocess_params.rec_model_dir is None:
|
||||
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
|
||||
if params.rec_model_dir is None:
|
||||
params.rec_model_dir = os.path.join(BASE_DIR, VERSION,
|
||||
'rec', lang)
|
||||
if postprocess_params.cls_model_dir is None:
|
||||
postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
|
||||
print(postprocess_params)
|
||||
if params.cls_model_dir is None:
|
||||
params.cls_model_dir = os.path.join(BASE_DIR, 'cls')
|
||||
# download model
|
||||
maybe_download(postprocess_params.det_model_dir,
|
||||
maybe_download(params.det_model_dir,
|
||||
model_urls['det'][det_lang])
|
||||
maybe_download(postprocess_params.rec_model_dir,
|
||||
maybe_download(params.rec_model_dir,
|
||||
model_urls['rec'][lang]['url'])
|
||||
maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
|
||||
maybe_download(params.cls_model_dir, model_urls['cls'])
|
||||
|
||||
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
|
||||
if params.det_algorithm not in SUPPORT_DET_MODEL:
|
||||
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
|
||||
sys.exit(0)
|
||||
if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL:
|
||||
if params.rec_algorithm not in SUPPORT_REC_MODEL:
|
||||
logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL))
|
||||
sys.exit(0)
|
||||
if use_inner_dict:
|
||||
postprocess_params.rec_char_dict_path = str(
|
||||
Path(__file__).parent / postprocess_params.rec_char_dict_path)
|
||||
params.rec_char_dict_path = str(
|
||||
Path(__file__).parent / params.rec_char_dict_path)
|
||||
|
||||
print(params)
|
||||
# init det_model and rec_model
|
||||
super().__init__(postprocess_params)
|
||||
super().__init__(params)
|
||||
|
||||
def ocr(self, img, det=True, rec=True, cls=True):
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,66 @@
|
|||
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tarfile
|
||||
import requests
|
||||
from tqdm import tqdm
|
||||
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
def download_with_progressbar(url, save_path):
|
||||
logger = get_logger()
|
||||
response = requests.get(url, stream=True)
|
||||
total_size_in_bytes = int(response.headers.get('content-length', 0))
|
||||
block_size = 1024 # 1 Kibibyte
|
||||
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
|
||||
with open(save_path, 'wb') as file:
|
||||
for data in response.iter_content(block_size):
|
||||
progress_bar.update(len(data))
|
||||
file.write(data)
|
||||
progress_bar.close()
|
||||
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
|
||||
logger.error("Something went wrong while downloading models")
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def maybe_download(model_storage_directory, url):
|
||||
# using custom model
|
||||
tar_file_name_list = [
|
||||
'inference.pdiparams', 'inference.pdiparams.info', 'inference.pdmodel'
|
||||
]
|
||||
if not os.path.exists(
|
||||
os.path.join(model_storage_directory, 'inference.pdiparams')
|
||||
) or not os.path.exists(
|
||||
os.path.join(model_storage_directory, 'inference.pdmodel')):
|
||||
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
|
||||
print('download {} to {}'.format(url, tmp_path))
|
||||
os.makedirs(model_storage_directory, exist_ok=True)
|
||||
download_with_progressbar(url, tmp_path)
|
||||
with tarfile.open(tmp_path, 'r') as tarObj:
|
||||
for member in tarObj.getmembers():
|
||||
filename = None
|
||||
for tar_file_name in tar_file_name_list:
|
||||
if tar_file_name in member.name:
|
||||
filename = tar_file_name
|
||||
if filename is None:
|
||||
continue
|
||||
file = tarObj.extractfile(member)
|
||||
with open(
|
||||
os.path.join(model_storage_directory, filename),
|
||||
'wb') as f:
|
||||
f.write(file.read())
|
||||
os.remove(tmp_path)
|
||||
|
Loading…
Reference in New Issue