whl package support send model url
This commit is contained in:
parent
ac56eba7eb
commit
2f0740685e
|
@ -20,6 +20,7 @@ from tqdm import tqdm
|
|||
|
||||
from ppocr.utils.logging import get_logger
|
||||
|
||||
|
||||
def download_with_progressbar(url, save_path):
|
||||
logger = get_logger()
|
||||
response = requests.get(url, stream=True)
|
||||
|
@ -45,6 +46,7 @@ def maybe_download(model_storage_directory, url):
|
|||
os.path.join(model_storage_directory, 'inference.pdiparams')
|
||||
) or not os.path.exists(
|
||||
os.path.join(model_storage_directory, 'inference.pdmodel')):
|
||||
assert url.endswith('.tar'), 'Only supports tar compressed package'
|
||||
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
|
||||
print('download {} to {}'.format(url, tmp_path))
|
||||
os.makedirs(model_storage_directory, exist_ok=True)
|
||||
|
@ -64,3 +66,17 @@ def maybe_download(model_storage_directory, url):
|
|||
f.write(file.read())
|
||||
os.remove(tmp_path)
|
||||
|
||||
|
||||
def is_link(s):
|
||||
return s is not None and s.startswith('http')
|
||||
|
||||
|
||||
def confirm_model_dir_url(model_dir, default_model_dir, default_url):
|
||||
url = default_url
|
||||
if model_dir is None or is_link(model_dir):
|
||||
if is_link(model_dir):
|
||||
url = model_dir
|
||||
file_name = url.split('/')[-1][:-4]
|
||||
model_dir = default_model_dir
|
||||
model_dir = os.path.join(model_dir, file_name)
|
||||
return model_dir, url
|
||||
|
|
|
@ -30,7 +30,7 @@ from ppstructure.utility import init_args, draw_result
|
|||
|
||||
logger = get_logger()
|
||||
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
|
||||
from ppocr.utils.network import maybe_download, download_with_progressbar
|
||||
from ppocr.utils.network import maybe_download, download_with_progressbar, confirm_model_dir_url, is_link
|
||||
|
||||
__all__ = ['PaddleStructure', 'draw_result', 'to_excel']
|
||||
|
||||
|
@ -70,16 +70,19 @@ class PaddleStructure(OCRSystem):
|
|||
logger.setLevel(logging.DEBUG)
|
||||
params.use_angle_cls = False
|
||||
# init model dir
|
||||
if params.det_model_dir is None:
|
||||
params.det_model_dir = os.path.join(BASE_DIR, VERSION, 'det')
|
||||
if params.rec_model_dir is None:
|
||||
params.rec_model_dir = os.path.join(BASE_DIR, VERSION, 'rec')
|
||||
if params.structure_model_dir is None:
|
||||
params.structure_model_dir = os.path.join(BASE_DIR, VERSION, 'structure')
|
||||
params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'det'),
|
||||
model_urls['det'])
|
||||
params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'rec'),
|
||||
model_urls['rec'])
|
||||
params.structure_model_dir, structure_url = confirm_model_dir_url(params.structure_model_dir,
|
||||
os.path.join(BASE_DIR, VERSION, 'structure'),
|
||||
model_urls['structure'])
|
||||
# download model
|
||||
maybe_download(params.det_model_dir, model_urls['det'])
|
||||
maybe_download(params.rec_model_dir, model_urls['rec'])
|
||||
maybe_download(params.structure_model_dir, model_urls['structure'])
|
||||
maybe_download(params.det_model_dir, det_url)
|
||||
maybe_download(params.rec_model_dir, rec_url)
|
||||
maybe_download(params.structure_model_dir, structure_url)
|
||||
|
||||
if params.rec_char_dict_path is None:
|
||||
params.rec_char_type = 'EN'
|
||||
|
@ -143,3 +146,24 @@ def main():
|
|||
logger.info(item['res'])
|
||||
save_res(result, save_folder, img_name)
|
||||
logger.info('result save to {}'.format(os.path.join(save_folder, img_name)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
table_engine = PaddleStructure(
|
||||
output='/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/output/table',
|
||||
show_log=True)
|
||||
|
||||
img_path = '/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR/ppstructure/test_imgs/paper-image.jpg'
|
||||
img = cv2.imread(img_path)
|
||||
result = table_engine(img)
|
||||
for line in result:
|
||||
print(line)
|
||||
|
||||
from PIL import Image
|
||||
|
||||
font_path = '/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR//doc/fonts/simfang.ttf'
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
im_show = draw_result(image, result,
|
||||
font_path='/Users/zhoujun20/Desktop/工作相关/table/table_pr/PaddleOCR//doc/fonts/simfang.ttf')
|
||||
im_show = Image.fromarray(im_show)
|
||||
im_show.save('result.jpg')
|
||||
|
|
Loading…
Reference in New Issue