| 
									
										
										
										
											2020-08-22 19:42:14 +08:00
										 |  |  |  | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | 
					
						
							|  |  |  |  | # | 
					
						
							|  |  |  |  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  |  | # you may not use this file except in compliance with the License. | 
					
						
							|  |  |  |  | # You may obtain a copy of the License at | 
					
						
							|  |  |  |  | # | 
					
						
							|  |  |  |  | #     http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  |  | # | 
					
						
							|  |  |  |  | # Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  |  | # distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  |  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  |  | # See the License for the specific language governing permissions and | 
					
						
							|  |  |  |  | # limitations under the License. | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | import os | 
					
						
							|  |  |  |  | import sys | 
					
						
							|  |  |  |  | 
 | 
					
						
# Directory that contains this module file.
__dir__ = os.path.dirname(__file__)
# Append this module's directory to the import search path (joining with ''
# only adds a trailing path separator). Presumably this lets the sibling
# `tools` and `ppocr` packages imported below resolve when the module is
# loaded from an installed package -- TODO confirm.
sys.path.append(os.path.join(__dir__, ''))
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | import cv2 | 
					
						
							|  |  |  |  | import numpy as np | 
					
						
							|  |  |  |  | from pathlib import Path | 
					
						
							|  |  |  |  | import tarfile | 
					
						
							|  |  |  |  | import requests | 
					
						
							|  |  |  |  | from tqdm import tqdm | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | from tools.infer import predict_system | 
					
						
							|  |  |  |  | from ppocr.utils.utility import initial_logger | 
					
						
							|  |  |  |  | 
 | 
					
						
# Module-level logger shared by the download helper, PaddleOCR and main().
logger = initial_logger()
					
						
							| 
									
										
										
										
											2020-08-24 11:30:56 +08:00
										 |  |  |  | from ppocr.utils.utility import check_and_read_gif, get_image_file_list | 
					
						
							| 
									
										
										
										
											2020-08-22 19:42:14 +08:00
										 |  |  |  | 
 | 
					
						
# Public API of this module: `from <module> import *` exports only PaddleOCR.
__all__ = ['PaddleOCR']
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | model_params = { | 
					
						
							| 
									
										
										
										
											2020-08-24 11:30:56 +08:00
										 |  |  |  |     'det': 'https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar', | 
					
						
							|  |  |  |  |     'rec': | 
					
						
							|  |  |  |  |     'https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar', | 
					
						
							| 
									
										
										
										
											2020-08-22 19:42:14 +08:00
										 |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | SUPPORT_DET_MODEL = ['DB'] | 
					
						
							| 
									
										
										
										
											2020-08-24 11:30:56 +08:00
										 |  |  |  | SUPPORT_REC_MODEL = ['CRNN'] | 
					
						
							|  |  |  |  | BASE_DIR = os.path.expanduser("~/.paddleocr/") | 
					
						
							| 
									
										
										
										
											2020-08-22 19:42:14 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def download_with_progressbar(url, save_path): | 
					
						
							|  |  |  |  |     response = requests.get(url, stream=True) | 
					
						
							|  |  |  |  |     total_size_in_bytes = int(response.headers.get('content-length', 0)) | 
					
						
							|  |  |  |  |     block_size = 1024  # 1 Kibibyte | 
					
						
							|  |  |  |  |     progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) | 
					
						
							|  |  |  |  |     with open(save_path, 'wb') as file: | 
					
						
							|  |  |  |  |         for data in response.iter_content(block_size): | 
					
						
							|  |  |  |  |             progress_bar.update(len(data)) | 
					
						
							|  |  |  |  |             file.write(data) | 
					
						
							|  |  |  |  |     progress_bar.close() | 
					
						
							|  |  |  |  |     if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: | 
					
						
							|  |  |  |  |         logger.error("ERROR, something went wrong") | 
					
						
							|  |  |  |  |         sys.exit(0) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-24 11:30:56 +08:00
										 |  |  |  | def maybe_download(model_storage_directory, url): | 
					
						
							| 
									
										
										
										
											2020-08-22 19:42:14 +08:00
										 |  |  |  |     # using custom model | 
					
						
							| 
									
										
										
										
											2020-08-24 11:30:56 +08:00
										 |  |  |  |     if not os.path.exists(os.path.join( | 
					
						
							|  |  |  |  |             model_storage_directory, 'model')) or not os.path.exists( | 
					
						
							|  |  |  |  |                 os.path.join(model_storage_directory, 'params')): | 
					
						
							|  |  |  |  |         tmp_path = os.path.join(model_storage_directory, url.split('/')[-1]) | 
					
						
							|  |  |  |  |         print('download {} to {}'.format(url, tmp_path)) | 
					
						
							|  |  |  |  |         os.makedirs(model_storage_directory, exist_ok=True) | 
					
						
							|  |  |  |  |         download_with_progressbar(url, tmp_path) | 
					
						
							|  |  |  |  |         with tarfile.open(tmp_path, 'r') as tarObj: | 
					
						
							|  |  |  |  |             for member in tarObj.getmembers(): | 
					
						
							|  |  |  |  |                 if "model" in member.name: | 
					
						
							|  |  |  |  |                     filename = 'model' | 
					
						
							|  |  |  |  |                 elif "params" in member.name: | 
					
						
							|  |  |  |  |                     filename = 'params' | 
					
						
							|  |  |  |  |                 else: | 
					
						
							|  |  |  |  |                     continue | 
					
						
							|  |  |  |  |                 file = tarObj.extractfile(member) | 
					
						
							|  |  |  |  |                 with open( | 
					
						
							|  |  |  |  |                         os.path.join(model_storage_directory, filename), | 
					
						
							|  |  |  |  |                         'wb') as f: | 
					
						
							|  |  |  |  |                     f.write(file.read()) | 
					
						
							|  |  |  |  |         os.remove(tmp_path) | 
					
						
							| 
									
										
										
										
											2020-08-22 19:42:14 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def parse_args(): | 
					
						
							|  |  |  |  |     import argparse | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def str2bool(v): | 
					
						
							|  |  |  |  |         return v.lower() in ("true", "t", "1") | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     parser = argparse.ArgumentParser() | 
					
						
							|  |  |  |  |     # params for prediction engine | 
					
						
							|  |  |  |  |     parser.add_argument("--use_gpu", type=str2bool, default=True) | 
					
						
							|  |  |  |  |     parser.add_argument("--ir_optim", type=str2bool, default=True) | 
					
						
							|  |  |  |  |     parser.add_argument("--use_tensorrt", type=str2bool, default=False) | 
					
						
							|  |  |  |  |     parser.add_argument("--gpu_mem", type=int, default=8000) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # params for text detector | 
					
						
							|  |  |  |  |     parser.add_argument("--image_dir", type=str) | 
					
						
							|  |  |  |  |     parser.add_argument("--det_algorithm", type=str, default='DB') | 
					
						
							| 
									
										
										
										
											2020-08-24 11:30:56 +08:00
										 |  |  |  |     parser.add_argument("--det_model_dir", type=str, default=None) | 
					
						
							| 
									
										
										
										
											2020-08-22 19:42:14 +08:00
										 |  |  |  |     parser.add_argument("--det_max_side_len", type=float, default=960) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # DB parmas | 
					
						
							|  |  |  |  |     parser.add_argument("--det_db_thresh", type=float, default=0.3) | 
					
						
							|  |  |  |  |     parser.add_argument("--det_db_box_thresh", type=float, default=0.5) | 
					
						
							|  |  |  |  |     parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # EAST parmas | 
					
						
							|  |  |  |  |     parser.add_argument("--det_east_score_thresh", type=float, default=0.8) | 
					
						
							|  |  |  |  |     parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) | 
					
						
							|  |  |  |  |     parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # params for text recognizer | 
					
						
							|  |  |  |  |     parser.add_argument("--rec_algorithm", type=str, default='CRNN') | 
					
						
							| 
									
										
										
										
											2020-08-24 11:30:56 +08:00
										 |  |  |  |     parser.add_argument("--rec_model_dir", type=str, default=None) | 
					
						
							| 
									
										
										
										
											2020-08-22 19:42:14 +08:00
										 |  |  |  |     parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") | 
					
						
							|  |  |  |  |     parser.add_argument("--rec_char_type", type=str, default='ch') | 
					
						
							|  |  |  |  |     parser.add_argument("--rec_batch_num", type=int, default=30) | 
					
						
							| 
									
										
										
										
											2020-08-24 11:30:56 +08:00
										 |  |  |  |     parser.add_argument("--max_text_length", type=int, default=25) | 
					
						
							| 
									
										
										
										
											2020-08-22 19:42:14 +08:00
										 |  |  |  |     parser.add_argument( | 
					
						
							|  |  |  |  |         "--rec_char_dict_path", | 
					
						
							|  |  |  |  |         type=str, | 
					
						
							|  |  |  |  |         default="./ppocr/utils/ppocr_keys_v1.txt") | 
					
						
							|  |  |  |  |     parser.add_argument("--use_space_char", type=bool, default=True) | 
					
						
							|  |  |  |  |     parser.add_argument("--enable_mkldnn", type=bool, default=False) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     parser.add_argument("--det", type=str2bool, default=True) | 
					
						
							|  |  |  |  |     parser.add_argument("--rec", type=str2bool, default=True) | 
					
						
							| 
									
										
										
										
											2020-08-27 15:10:41 +08:00
										 |  |  |  |     parser.add_argument("--use_zero_copy_run", type=bool, default=False) | 
					
						
							| 
									
										
										
										
											2020-08-22 19:42:14 +08:00
										 |  |  |  |     return parser.parse_args() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | class PaddleOCR(predict_system.TextSystem): | 
					
						
							| 
									
										
										
										
											2020-08-24 11:30:56 +08:00
										 |  |  |  |     def __init__(self, **kwargs): | 
					
						
							| 
									
										
										
										
											2020-08-22 19:42:14 +08:00
										 |  |  |  |         """
 | 
					
						
							|  |  |  |  |         paddleocr package | 
					
						
							|  |  |  |  |         args: | 
					
						
							|  |  |  |  |             **kwargs: other params show in paddleocr --help | 
					
						
							|  |  |  |  |         """
 | 
					
						
							|  |  |  |  |         postprocess_params = parse_args() | 
					
						
							| 
									
										
										
										
											2020-08-24 11:30:56 +08:00
										 |  |  |  |         postprocess_params.__dict__.update(**kwargs) | 
					
						
							| 
									
										
										
										
											2020-08-22 19:42:14 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-08-24 11:30:56 +08:00
										 |  |  |  |         # init model dir | 
					
						
							|  |  |  |  |         if postprocess_params.det_model_dir is None: | 
					
						
							|  |  |  |  |             postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det') | 
					
						
							|  |  |  |  |         if postprocess_params.rec_model_dir is None: | 
					
						
							|  |  |  |  |             postprocess_params.rec_model_dir = os.path.join(BASE_DIR, 'rec') | 
					
						
							|  |  |  |  |         print(postprocess_params) | 
					
						
							| 
									
										
										
										
											2020-08-22 19:42:14 +08:00
										 |  |  |  |         # download model | 
					
						
							| 
									
										
										
										
											2020-08-24 11:30:56 +08:00
										 |  |  |  |         maybe_download(postprocess_params.det_model_dir, model_params['det']) | 
					
						
							|  |  |  |  |         maybe_download(postprocess_params.rec_model_dir, model_params['rec']) | 
					
						
							| 
									
										
										
										
											2020-08-22 19:42:14 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  |         if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL: | 
					
						
							|  |  |  |  |             logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) | 
					
						
							|  |  |  |  |             sys.exit(0) | 
					
						
							|  |  |  |  |         if postprocess_params.rec_algorithm not in SUPPORT_REC_MODEL: | 
					
						
							|  |  |  |  |             logger.error('rec_algorithm must in {}'.format(SUPPORT_REC_MODEL)) | 
					
						
							|  |  |  |  |             sys.exit(0) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         postprocess_params.rec_char_dict_path = Path( | 
					
						
							|  |  |  |  |             __file__).parent / postprocess_params.rec_char_dict_path | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |         # init det_model and rec_model | 
					
						
							|  |  |  |  |         super().__init__(postprocess_params) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     def ocr(self, img, det=True, rec=True): | 
					
						
							|  |  |  |  |         """
 | 
					
						
							|  |  |  |  |         ocr with paddleocr | 
					
						
							|  |  |  |  |         args: | 
					
						
							|  |  |  |  |             img: img for ocr, support ndarray, img_path and list or ndarray | 
					
						
							|  |  |  |  |             det: use text detection or not, if false, only rec will be exec. default is True | 
					
						
							|  |  |  |  |             rec: use text recognition or not, if false, only det will be exec. default is True | 
					
						
							|  |  |  |  |         """
 | 
					
						
							|  |  |  |  |         assert isinstance(img, (np.ndarray, list, str)) | 
					
						
							|  |  |  |  |         if isinstance(img, str): | 
					
						
							|  |  |  |  |             image_file = img | 
					
						
							|  |  |  |  |             img, flag = check_and_read_gif(image_file) | 
					
						
							|  |  |  |  |             if not flag: | 
					
						
							|  |  |  |  |                 img = cv2.imread(image_file) | 
					
						
							|  |  |  |  |             if img is None: | 
					
						
							|  |  |  |  |                 logger.error("error in loading image:{}".format(image_file)) | 
					
						
							|  |  |  |  |                 return None | 
					
						
							|  |  |  |  |         if det and rec: | 
					
						
							|  |  |  |  |             dt_boxes, rec_res = self.__call__(img) | 
					
						
							|  |  |  |  |             return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] | 
					
						
							|  |  |  |  |         elif det and not rec: | 
					
						
							|  |  |  |  |             dt_boxes, elapse = self.text_detector(img) | 
					
						
							|  |  |  |  |             if dt_boxes is None: | 
					
						
							|  |  |  |  |                 return None | 
					
						
							|  |  |  |  |             return [box.tolist() for box in dt_boxes] | 
					
						
							|  |  |  |  |         else: | 
					
						
							|  |  |  |  |             if not isinstance(img, list): | 
					
						
							|  |  |  |  |                 img = [img] | 
					
						
							|  |  |  |  |             rec_res, elapse = self.text_recognizer(img) | 
					
						
							|  |  |  |  |             return rec_res | 
					
						
							| 
									
										
										
										
											2020-08-24 11:30:56 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | def main(): | 
					
						
							|  |  |  |  |     # for com | 
					
						
							|  |  |  |  |     args = parse_args() | 
					
						
							|  |  |  |  |     image_file_list = get_image_file_list(args.image_dir) | 
					
						
							|  |  |  |  |     if len(image_file_list) == 0: | 
					
						
							|  |  |  |  |         logger.error('no images find in {}'.format(args.image_dir)) | 
					
						
							|  |  |  |  |         return | 
					
						
							|  |  |  |  |     ocr_engine = PaddleOCR() | 
					
						
							|  |  |  |  |     for img_path in image_file_list: | 
					
						
							|  |  |  |  |         print(img_path) | 
					
						
							|  |  |  |  |         result = ocr_engine.ocr(img_path, det=args.det, rec=args.rec) | 
					
						
							|  |  |  |  |         for line in result: | 
					
						
							| 
									
										
										
										
											2020-08-27 15:10:41 +08:00
										 |  |  |  |             print(line) |