fix some errors and bugs (#1185)

fix some errors in pip
This commit is contained in:
shaohua.zhang 2020-11-21 22:11:53 +08:00 committed by GitHub
parent 0361647094
commit a9d709315a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 12 deletions

View File

@ -87,8 +87,8 @@ def download_with_progressbar(url, save_path):
progress_bar.update(len(data)) progress_bar.update(len(data))
file.write(data) file.write(data)
progress_bar.close() progress_bar.close()
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
logger.error("ERROR, something went wrong") logger.error("Something went wrong while downloading models")
sys.exit(0) sys.exit(0)
@ -157,7 +157,6 @@ def parse_args():
parser.add_argument("--use_space_char", type=bool, default=True) parser.add_argument("--use_space_char", type=bool, default=True)
# params for text classifier # params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
parser.add_argument("--cls_model_dir", type=str, default=None) parser.add_argument("--cls_model_dir", type=str, default=None)
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
parser.add_argument("--label_list", type=list, default=['0', '180']) parser.add_argument("--label_list", type=list, default=['0', '180'])
@ -171,7 +170,7 @@ def parse_args():
parser.add_argument("--lang", type=str, default='ch') parser.add_argument("--lang", type=str, default='ch')
parser.add_argument("--det", type=str2bool, default=True) parser.add_argument("--det", type=str2bool, default=True)
parser.add_argument("--rec", type=str2bool, default=True) parser.add_argument("--rec", type=str2bool, default=True)
parser.add_argument("--cls", type=str2bool, default=False) parser.add_argument("--use_angle_cls", type=str2bool, default=True)
return parser.parse_args() return parser.parse_args()
@ -206,7 +205,6 @@ class PaddleOCR(predict_system.TextSystem):
maybe_download(postprocess_params.det_model_dir, model_urls['det']) maybe_download(postprocess_params.det_model_dir, model_urls['det'])
maybe_download(postprocess_params.rec_model_dir, maybe_download(postprocess_params.rec_model_dir,
model_urls['rec'][lang]['url']) model_urls['rec'][lang]['url'])
if self.use_angle_cls:
maybe_download(postprocess_params.cls_model_dir, model_urls['cls']) maybe_download(postprocess_params.cls_model_dir, model_urls['cls'])
if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL: if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL:
@ -231,9 +229,6 @@ class PaddleOCR(predict_system.TextSystem):
rec: use text recognition or not, if false, only det will be exec. default is True rec: use text recognition or not, if false, only det will be exec. default is True
""" """
assert isinstance(img, (np.ndarray, list, str)) assert isinstance(img, (np.ndarray, list, str))
if cls and not self.use_angle_cls:
print('cls should be false when use_angle_cls is false')
exit(-1)
self.use_angle_cls = cls self.use_angle_cls = cls
if isinstance(img, str): if isinstance(img, str):
image_file = img image_file = img
@ -275,6 +270,7 @@ def main():
result = ocr_engine.ocr(img_path, result = ocr_engine.ocr(img_path,
det=args.det, det=args.det,
rec=args.rec, rec=args.rec,
cls=args.cls) cls=args.use_angle_cls)
if result is not None:
for line in result: for line in result:
print(line) print(line)