model_storage_directory参数取消,改为det_model_dir和rec_model_dir
This commit is contained in:
parent
09e15a684a
commit
c581ff51a8
133
paddleocr.py
133
paddleocr.py
|
@ -29,25 +29,19 @@ from tools.infer import predict_system
|
|||
from ppocr.utils.utility import initial_logger
|
||||
|
||||
logger = initial_logger()
|
||||
from ppocr.utils.utility import check_and_read_gif
|
||||
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
|
||||
|
||||
__all__ = ['PaddleOCR']
|
||||
|
||||
model_params = {
|
||||
'ch_det_mv3_db': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar',
|
||||
'algorithm': 'DB',
|
||||
},
|
||||
'ch_rec_mv3_crnn_enhance': {
|
||||
'url':
|
||||
'https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar',
|
||||
'algorithm': 'CRNN'
|
||||
},
|
||||
'det': 'https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar',
|
||||
'rec':
|
||||
'https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar',
|
||||
}
|
||||
|
||||
SUPPORT_DET_MODEL = ['DB']
|
||||
SUPPORT_REC_MODEL = ['Rosetta', 'CRNN', 'STARNet', 'RARE']
|
||||
SUPPORT_REC_MODEL = ['CRNN']
|
||||
BASE_DIR = os.path.expanduser("~/.paddleocr/")
|
||||
|
||||
|
||||
def download_with_progressbar(url, save_path):
|
||||
|
@ -65,34 +59,29 @@ def download_with_progressbar(url, save_path):
|
|||
sys.exit(0)
|
||||
|
||||
|
||||
def download_and_unzip(url, model_storage_directory):
    """Download a tar archive and unpack it into ``model_storage_directory``.

    The archive is saved next to its extraction target under the last path
    component of ``url``, unpacked, and then deleted.

    Args:
        url: Location of the tar archive to fetch.
        model_storage_directory: Directory that receives both the temporary
            archive and the extracted files. Created if it does not exist.

    Raises:
        ValueError: If an archive member would be written outside
            ``model_storage_directory``.
    """
    tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
    print('download {} to {}'.format(url, tmp_path))
    os.makedirs(model_storage_directory, exist_ok=True)

    target_root = os.path.realpath(model_storage_directory)
    # Trailing separator so that "/a/b" does not accept "/a/bc" as a child.
    target_prefix = os.path.join(target_root, '')
    try:
        download_with_progressbar(url, tmp_path)
        with tarfile.open(tmp_path, 'r') as tar_obj:
            members = tar_obj.getmembers()
            # The archive comes from the network: validate every member
            # before extracting anything so a malicious entry such as
            # "../../x" cannot escape the target directory.
            for member in members:
                destination = os.path.realpath(
                    os.path.join(target_root, member.name))
                if (destination != target_root and
                        not destination.startswith(target_prefix)):
                    raise ValueError(
                        'unsafe path in archive {}: {}'.format(
                            tmp_path, member.name))
            for member in members:
                tar_obj.extract(member, model_storage_directory)
    finally:
        # Remove the temporary archive even when download or extraction
        # fails, so a partial file is never left behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
|
||||
|
||||
|
||||
def maybe_download(model_storage_directory, model_name, mode='det'):
|
||||
algorithm = None
|
||||
def maybe_download(model_storage_directory, url):
|
||||
# using custom model
|
||||
if os.path.exists(os.path.join(model_name, 'model')) and os.path.exists(
|
||||
os.path.join(model_name, 'params')):
|
||||
return model_name, algorithm
|
||||
# using the model of paddleocr
|
||||
model_path = os.path.join(model_storage_directory, model_name)
|
||||
if not os.path.exists(os.path.join(model_path,
|
||||
'model')) or not os.path.exists(
|
||||
os.path.join(model_path, 'params')):
|
||||
assert model_name in model_params, 'model must in {}'.format(
|
||||
model_params.keys())
|
||||
download_and_unzip(model_params[model_name]['url'],
|
||||
model_storage_directory)
|
||||
algorithm = model_params[model_name]['algorithm']
|
||||
return model_path, algorithm
|
||||
if not os.path.exists(os.path.join(
|
||||
model_storage_directory, 'model')) or not os.path.exists(
|
||||
os.path.join(model_storage_directory, 'params')):
|
||||
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
|
||||
print('download {} to {}'.format(url, tmp_path))
|
||||
os.makedirs(model_storage_directory, exist_ok=True)
|
||||
download_with_progressbar(url, tmp_path)
|
||||
with tarfile.open(tmp_path, 'r') as tarObj:
|
||||
for member in tarObj.getmembers():
|
||||
if "model" in member.name:
|
||||
filename = 'model'
|
||||
elif "params" in member.name:
|
||||
filename = 'params'
|
||||
else:
|
||||
continue
|
||||
file = tarObj.extractfile(member)
|
||||
with open(
|
||||
os.path.join(model_storage_directory, filename),
|
||||
'wb') as f:
|
||||
f.write(file.read())
|
||||
os.remove(tmp_path)
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -111,7 +100,7 @@ def parse_args():
|
|||
# params for text detector
|
||||
parser.add_argument("--image_dir", type=str)
|
||||
parser.add_argument("--det_algorithm", type=str, default='DB')
|
||||
parser.add_argument("--det_model_name", type=str, default='ch_det_mv3_db')
|
||||
parser.add_argument("--det_model_dir", type=str, default=None)
|
||||
parser.add_argument("--det_max_side_len", type=float, default=960)
|
||||
|
||||
# DB params
|
||||
|
@ -126,11 +115,11 @@ def parse_args():
|
|||
|
||||
# params for text recognizer
|
||||
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
|
||||
parser.add_argument(
|
||||
"--rec_model_name", type=str, default='ch_rec_mv3_crnn_enhance')
|
||||
parser.add_argument("--rec_model_dir", type=str, default=None)
|
||||
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
|
||||
parser.add_argument("--rec_char_type", type=str, default='ch')
|
||||
parser.add_argument("--rec_batch_num", type=int, default=30)
|
||||
parser.add_argument("--max_text_length", type=int, default=25)
|
||||
parser.add_argument(
|
||||
"--rec_char_dict_path",
|
||||
type=str,
|
||||
|
@ -138,53 +127,30 @@ def parse_args():
|
|||
parser.add_argument("--use_space_char", type=bool, default=True)
|
||||
parser.add_argument("--enable_mkldnn", type=bool, default=False)
|
||||
|
||||
parser.add_argument("--model_storage_directory", type=str, default=False)
|
||||
parser.add_argument("--det", type=str2bool, default=True)
|
||||
parser.add_argument("--rec", type=str2bool, default=True)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
class PaddleOCR(predict_system.TextSystem):
|
||||
def __init__(self,
|
||||
det_model_name='ch_det_mv3_db',
|
||||
rec_model_name='ch_rec_mv3_crnn_enhance',
|
||||
model_storage_directory=None,
|
||||
log_level=20,
|
||||
**kwargs):
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
paddleocr package
|
||||
args:
|
||||
det_model_name: det_model name, keep same with filename in paddleocr. default is ch_det_mv3_db
|
||||
rec_model_name: rec_model name, keep same with filename in paddleocr. default is ch_rec_mv3_crnn_enhance
|
||||
model_storage_directory: model save path. default is ~/.paddleocr
|
||||
det model will save to model_storage_directory/det_model
|
||||
rec model will save to model_storage_directory/rec_model
|
||||
log_level:
|
||||
**kwargs: other params show in paddleocr --help
|
||||
"""
|
||||
logger.setLevel(log_level)
|
||||
postprocess_params = parse_args()
|
||||
# init model dir
|
||||
if model_storage_directory:
|
||||
self.model_storage_directory = model_storage_directory
|
||||
else:
|
||||
self.model_storage_directory = os.path.expanduser(
|
||||
"~/.paddleocr/") + '/model'
|
||||
Path(self.model_storage_directory).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# download model
|
||||
det_model_path, det_algorithm = maybe_download(
|
||||
self.model_storage_directory, det_model_name, 'det')
|
||||
rec_model_path, rec_algorithm = maybe_download(
|
||||
self.model_storage_directory, rec_model_name, 'rec')
|
||||
# update model and post_process params
|
||||
postprocess_params.__dict__.update(**kwargs)
|
||||
postprocess_params.det_model_dir = det_model_path
|
||||
postprocess_params.rec_model_dir = rec_model_path
|
||||
if det_algorithm is not None:
|
||||
postprocess_params.det_algorithm = det_algorithm
|
||||
if rec_algorithm is not None:
|
||||
postprocess_params.rec_algorithm = rec_algorithm
|
||||
|
||||
# init model dir
|
||||
if postprocess_params.det_model_dir is None:
|
||||
postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det')
|
||||
if postprocess_params.rec_model_dir is None:
|
||||
postprocess_params.rec_model_dir = os.path.join(BASE_DIR, 'rec')
|
||||
print(postprocess_params)
|
||||
# download model
|
||||
maybe_download(postprocess_params.det_model_dir, model_params['det'])
|
||||
maybe_download(postprocess_params.rec_model_dir, model_params['rec'])
|
||||
|
||||
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
|
||||
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
|
||||
|
@ -229,3 +195,18 @@ class PaddleOCR(predict_system.TextSystem):
|
|||
img = [img]
|
||||
rec_res, elapse = self.text_recognizer(img)
|
||||
return rec_res
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: run OCR on every image under ``--image_dir``.

    Parses the command-line arguments, collects the image files found at
    ``image_dir``, and prints each image path followed by every entry of the
    OCR result for that image. Logs an error and returns early when no image
    files are found.
    """
    cli_args = parse_args()
    image_paths = get_image_file_list(cli_args.image_dir)
    num_images = len(image_paths)
    if num_images == 0:
        logger.error('no images find in {}'.format(cli_args.image_dir))
        return

    engine = PaddleOCR()
    for image_path in image_paths:
        print(image_path)
        # The det/rec switches come straight from the parsed arguments.
        ocr_output = engine.ocr(
            image_path, det=cli_args.det, rec=cli_args.rec)
        for entry in ocr_output:
            print(entry)
|
||||
|
|
Loading…
Reference in New Issue