mirror of https://github.com/PaddlePaddle/PaddleOCR.git, synced 2025-10-31 17:59:11 +00:00

302 lines · 11 KiB · Python · Executable File
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))

os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import numpy as np
import time

import tools.infer.utility as utility
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
import json

logger = get_logger()


class TextDetector(object):
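    """Text detection predictor.

    Builds the preprocessing pipeline, the post-processing module and the
    Paddle (or ONNX Runtime) predictor for the detection algorithm selected in
    ``args``, and maps an input image to an array of detected text boxes when
    called.
    """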
    def __init__(self, args):
        self.args = args
        self.det_algorithm = args.det_algorithm
        self.use_onnx = args.use_onnx
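        # Default preprocessing pipeline: resize with a side-length limit,
        # normalize with ImageNet mean/std, transpose HWC -> CHW, and keep only
        # the image and shape entries. Individual algorithms may override the
        # resize step below.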
        pre_process_list = [{
            'DetResizeForTest': {
                'limit_side_len': args.det_limit_side_len,
                'limit_type': args.det_limit_type,
            }
        }, {
            'NormalizeImage': {
                'std': [0.229, 0.224, 0.225],
                'mean': [0.485, 0.456, 0.406],
                'scale': '1./255.',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': ['image', 'shape']
            }
        }]
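        # Select and configure the post-processing module for the chosen
        # detection algorithm (DB, EAST, SAST or PSE); any other value aborts.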
        postprocess_params = {}
        if self.det_algorithm == "DB":
            postprocess_params['name'] = 'DBPostProcess'
            postprocess_params["thresh"] = args.det_db_thresh
            postprocess_params["box_thresh"] = args.det_db_box_thresh
            postprocess_params["max_candidates"] = 1000
            postprocess_params["unclip_ratio"] = args.det_db_unclip_ratio
            postprocess_params["use_dilation"] = args.use_dilation
            postprocess_params["score_mode"] = args.det_db_score_mode
        elif self.det_algorithm == "EAST":
            postprocess_params['name'] = 'EASTPostProcess'
            postprocess_params["score_thresh"] = args.det_east_score_thresh
            postprocess_params["cover_thresh"] = args.det_east_cover_thresh
            postprocess_params["nms_thresh"] = args.det_east_nms_thresh
        elif self.det_algorithm == "SAST":
            pre_process_list[0] = {
                'DetResizeForTest': {
                    'resize_long': args.det_limit_side_len
                }
            }
            postprocess_params['name'] = 'SASTPostProcess'
            postprocess_params["score_thresh"] = args.det_sast_score_thresh
            postprocess_params["nms_thresh"] = args.det_sast_nms_thresh
            self.det_sast_polygon = args.det_sast_polygon
            if self.det_sast_polygon:
                postprocess_params["sample_pts_num"] = 6
                postprocess_params["expand_scale"] = 1.2
                postprocess_params["shrink_ratio_of_width"] = 0.2
            else:
                postprocess_params["sample_pts_num"] = 2
                postprocess_params["expand_scale"] = 1.0
                postprocess_params["shrink_ratio_of_width"] = 0.3
        elif self.det_algorithm == "PSE":
            postprocess_params['name'] = 'PSEPostProcess'
            postprocess_params["thresh"] = args.det_pse_thresh
            postprocess_params["box_thresh"] = args.det_pse_box_thresh
            postprocess_params["min_area"] = args.det_pse_min_area
            postprocess_params["box_type"] = args.det_pse_box_type
            postprocess_params["scale"] = args.det_pse_scale
            self.det_pse_box_type = args.det_pse_box_type
        else:
            logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
            sys.exit(0)

        self.preprocess_op = create_operators(pre_process_list)
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = utility.create_predictor(
            args, 'det', logger)

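        # For ONNX models exported with a fixed input shape, resize every image
        # to exactly that shape instead of using the limit_side_len strategy;
        # the preprocessing operators are then rebuilt with the updated config.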
        if self.use_onnx:
            img_h, img_w = self.input_tensor.shape[2:]
            if img_h is not None and img_w is not None and img_h > 0 and img_w > 0:
                pre_process_list[0] = {
                    'DetResizeForTest': {
                        'image_shape': [img_h, img_w]
                    }
                }
        self.preprocess_op = create_operators(pre_process_list)

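        # Optional benchmarking with auto_log: record preprocess, inference and
        # post-process timings when --benchmark is enabled.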
        if args.benchmark:
            import auto_log
            pid = os.getpid()
            gpu_id = utility.get_infer_gpuid()
            self.autolog = auto_log.AutoLogger(
                model_name="det",
                model_precision=args.precision,
                batch_size=1,
                data_shape="dynamic",
                save_path=None,
                inference_config=self.config,
                pids=pid,
                process_name=None,
                gpu_ids=gpu_id if args.use_gpu else None,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=2,
                logger=logger)

    def order_points_clockwise(self, pts):
        """Sort four points clockwise: top-left, top-right, bottom-right, bottom-left.

        Reference: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py
        """
        # sort the points based on their x-coordinates
        xSorted = pts[np.argsort(pts[:, 0]), :]

        # grab the left-most and right-most points from the sorted
        # x-coordinate points
        leftMost = xSorted[:2, :]
        rightMost = xSorted[2:, :]

        # now, sort the left-most coordinates according to their
        # y-coordinates so we can grab the top-left and bottom-left
        # points, respectively
        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
        (tl, bl) = leftMost

        rightMost = rightMost[np.argsort(rightMost[:, 1]), :]
        (tr, br) = rightMost

        rect = np.array([tl, tr, br, bl], dtype="float32")
        return rect

    def clip_det_res(self, points, img_height, img_width):
        for pno in range(points.shape[0]):
            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
        return points

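    # filter_tag_det_res orders each quad clockwise, clips it to the image and
    # drops boxes whose width or height is 3 px or less; the *_only_clip variant
    # just clips, and is used for polygon outputs (SAST polygon / PSE 'poly').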
    def filter_tag_det_res(self, dt_boxes, image_shape):
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.order_points_clockwise(box)
            box = self.clip_det_res(box, img_height, img_width)
            rect_width = int(np.linalg.norm(box[0] - box[1]))
            rect_height = int(np.linalg.norm(box[0] - box[3]))
            if rect_width <= 3 or rect_height <= 3:
                continue
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
        img_height, img_width = image_shape[0:2]
        dt_boxes_new = []
        for box in dt_boxes:
            box = self.clip_det_res(box, img_height, img_width)
            dt_boxes_new.append(box)
        dt_boxes = np.array(dt_boxes_new)
        return dt_boxes

    def __call__(self, img):
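        """Run detection on a BGR image array and return (dt_boxes, elapse),
        where elapse is the elapsed time in seconds, or (None, 0) if
        preprocessing fails."""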
        ori_im = img.copy()
        data = {'image': img}

        st = time.time()

        if self.args.benchmark:
            self.autolog.times.start()

        data = transform(data, self.preprocess_op)
        img, shape_list = data
        if img is None:
            return None, 0
        img = np.expand_dims(img, axis=0)
        shape_list = np.expand_dims(shape_list, axis=0)
        img = img.copy()

        if self.args.benchmark:
            self.autolog.times.stamp()
        if self.use_onnx:
            input_dict = {}
            input_dict[self.input_tensor.name] = img
            outputs = self.predictor.run(self.output_tensors, input_dict)
        else:
            self.input_tensor.copy_from_cpu(img)
            self.predictor.run()
            outputs = []
            for output_tensor in self.output_tensors:
                output = output_tensor.copy_to_cpu()
                outputs.append(output)
            if self.args.benchmark:
                self.autolog.times.stamp()

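        # Map the raw predictor outputs to the keys expected by the
        # post-processing module; the number and order of outputs depends on
        # the detection head.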
        preds = {}
        if self.det_algorithm == "EAST":
            preds['f_geo'] = outputs[0]
            preds['f_score'] = outputs[1]
        elif self.det_algorithm == 'SAST':
            preds['f_border'] = outputs[0]
            preds['f_score'] = outputs[1]
            preds['f_tco'] = outputs[2]
            preds['f_tvo'] = outputs[3]
        elif self.det_algorithm in ['DB', 'PSE']:
            preds['maps'] = outputs[0]
        else:
            raise NotImplementedError

        #self.predictor.try_shrink_memory()
        post_result = self.postprocess_op(preds, shape_list)
        dt_boxes = post_result[0]['points']
        if (self.det_algorithm == "SAST" and
                self.det_sast_polygon) or (self.det_algorithm == "PSE" and
                                           self.det_pse_box_type == 'poly'):
            dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
        else:
            dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)

        if self.args.benchmark:
            self.autolog.times.end(stamp=True)
        et = time.time()
        return dt_boxes, et - st


if __name__ == "__main__":
    args = utility.parse_args()
    image_file_list = get_image_file_list(args.image_dir)
    text_detector = TextDetector(args)
    count = 0
    total_time = 0
    draw_img_save = "./inference_results"

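    # Optional warmup: run two random 640x640 images through the detector so
    # one-off initialization cost is not attributed to the first real image.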
    if args.warmup:
        img = np.random.uniform(0, 255, [640, 640, 3]).astype(np.uint8)
        for i in range(2):
            res = text_detector(img)

    if not os.path.exists(draw_img_save):
        os.makedirs(draw_img_save)
    save_results = []
    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        st = time.time()
        dt_boxes, _ = text_detector(img)
        elapse = time.time() - st
        if count > 0:
            total_time += elapse
        count += 1
        save_pred = os.path.basename(image_file) + "\t" + str(
            json.dumps(np.array(dt_boxes).astype(np.int32).tolist())) + "\n"
        save_results.append(save_pred)
        logger.info(save_pred)
        logger.info("The predict time of {}: {}".format(image_file, elapse))
        src_im = utility.draw_text_det_res(dt_boxes, image_file)
        img_name_pure = os.path.split(image_file)[-1]
        img_path = os.path.join(draw_img_save,
                                "det_res_{}".format(img_name_pure))
        cv2.imwrite(img_path, src_im)
        logger.info("The visualized image saved in {}".format(img_path))

    with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f:
        f.writelines(save_results)
    if args.benchmark:
        text_detector.autolog.report()
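# A typical invocation, as a sketch only: the script name and the
# --det_model_dir flag are assumptions (they come from tools/infer/utility.parse_args,
# not from this file), and the model/image paths are placeholders.
#   python3 tools/infer/predict_det.py \
#       --det_algorithm="DB" \
#       --det_model_dir="./inference/det_db/" \
#       --image_dir="./doc/imgs/" \
#       --use_gpu=True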
