# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import sys
import cv2
import numpy as np
import paddle
import PIL
from PIL import Image, ImageDraw, ImageFont
import math
from paddle import inference
import random
from ppocr.utils.logging import get_logger


def str2bool(v):
    return v.lower() in ("true", "yes", "t", "y", "1")


def str2int_tuple(v):
    return tuple([int(i.strip()) for i in v.split(",")])


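# Usage examples (illustrative, matching the semantics above):
#   str2bool("True")            -> True   (matching is case-insensitive)
#   str2bool("0")               -> False  (any unlisted value is False)
#   str2int_tuple("3, 48, 320") -> (3, 48, 320)

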
def init_args():
    parser = argparse.ArgumentParser()
    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=True)
    parser.add_argument("--use_xpu", type=str2bool, default=False)
    parser.add_argument("--use_npu", type=str2bool, default=False)
    parser.add_argument("--use_mlu", type=str2bool, default=False)
    parser.add_argument(
        "--use_gcu",
        type=str2bool,
        default=False,
        help="Use Enflame GCU(General Compute Unit)",
    )
    parser.add_argument("--ir_optim", type=str2bool, default=True)
    parser.add_argument("--use_tensorrt", type=str2bool, default=False)
    parser.add_argument("--min_subgraph_size", type=int, default=15)
    parser.add_argument("--precision", type=str, default="fp32")
    parser.add_argument("--gpu_mem", type=int, default=500)
    parser.add_argument("--gpu_id", type=int, default=0)

    # params for text detector
    parser.add_argument("--image_dir", type=str)
    parser.add_argument("--page_num", type=int, default=0)
    parser.add_argument("--det_algorithm", type=str, default="DB")
    parser.add_argument("--det_model_dir", type=str)
    parser.add_argument("--det_limit_side_len", type=float, default=960)
    parser.add_argument("--det_limit_type", type=str, default="max")
    parser.add_argument("--det_box_type", type=str, default="quad")

    # DB params
    parser.add_argument("--det_db_thresh", type=float, default=0.3)
    parser.add_argument("--det_db_box_thresh", type=float, default=0.6)
    parser.add_argument("--det_db_unclip_ratio", type=float, default=1.5)
    parser.add_argument("--max_batch_size", type=int, default=10)
    parser.add_argument("--use_dilation", type=str2bool, default=False)
    parser.add_argument("--det_db_score_mode", type=str, default="fast")

    # EAST params
    parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
    parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
    parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)

    # SAST params
    parser.add_argument("--det_sast_score_thresh", type=float, default=0.5)
    parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)

    # PSE params
    parser.add_argument("--det_pse_thresh", type=float, default=0)
    parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
    parser.add_argument("--det_pse_min_area", type=float, default=16)
    parser.add_argument("--det_pse_scale", type=int, default=1)

    # FCE params
    parser.add_argument("--scales", type=list, default=[8, 16, 32])
    parser.add_argument("--alpha", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=1.0)
    parser.add_argument("--fourier_degree", type=int, default=5)

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default="SVTR_LCNet")
    parser.add_argument("--rec_model_dir", type=str)
    parser.add_argument("--rec_image_inverse", type=str2bool, default=True)
    parser.add_argument("--rec_image_shape", type=str, default="3, 48, 320")
    parser.add_argument("--rec_batch_num", type=int, default=6)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument(
        "--rec_char_dict_path", type=str, default="./ppocr/utils/ppocr_keys_v1.txt"
    )
    parser.add_argument("--use_space_char", type=str2bool, default=True)
    parser.add_argument("--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
    parser.add_argument("--drop_score", type=float, default=0.5)

    # params for e2e
    parser.add_argument("--e2e_algorithm", type=str, default="PGNet")
    parser.add_argument("--e2e_model_dir", type=str)
    parser.add_argument("--e2e_limit_side_len", type=float, default=768)
    parser.add_argument("--e2e_limit_type", type=str, default="max")

    # PGNet params
    parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
    parser.add_argument(
        "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt"
    )
    parser.add_argument("--e2e_pgnet_valid_set", type=str, default="totaltext")
    parser.add_argument("--e2e_pgnet_mode", type=str, default="fast")

    # params for text classifier
    parser.add_argument("--use_angle_cls", type=str2bool, default=False)
    parser.add_argument("--cls_model_dir", type=str)
    parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
    parser.add_argument("--label_list", type=list, default=["0", "180"])
    parser.add_argument("--cls_batch_num", type=int, default=6)
    parser.add_argument("--cls_thresh", type=float, default=0.9)

    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--cpu_threads", type=int, default=10)
    parser.add_argument("--use_pdserving", type=str2bool, default=False)
    parser.add_argument("--warmup", type=str2bool, default=False)

    # SR params
    parser.add_argument("--sr_model_dir", type=str)
    parser.add_argument("--sr_image_shape", type=str, default="3, 32, 128")
    parser.add_argument("--sr_batch_num", type=int, default=1)

    # output params
    parser.add_argument("--draw_img_save_dir", type=str, default="./inference_results")
    parser.add_argument("--save_crop_res", type=str2bool, default=False)
    parser.add_argument("--crop_res_save_dir", type=str, default="./output")

    # multi-process
    parser.add_argument("--use_mp", type=str2bool, default=False)
    parser.add_argument("--total_process_num", type=int, default=1)
    parser.add_argument("--process_id", type=int, default=0)

    parser.add_argument("--benchmark", type=str2bool, default=False)
    parser.add_argument("--save_log_path", type=str, default="./log_output/")

    parser.add_argument("--show_log", type=str2bool, default=True)
    parser.add_argument("--use_onnx", type=str2bool, default=False)
    parser.add_argument("--onnx_providers", nargs="+", type=str, default=False)
    parser.add_argument("--onnx_sess_options", type=list, default=False)

    # extended function
    parser.add_argument(
        "--return_word_box",
        type=str2bool,
        default=False,
        help="Whether to return the bbox of each word (split by space) or each Chinese character. Only used in ppstructure for layout recovery.",
    )

    return parser


def parse_args():
    parser = init_args()
    return parser.parse_args()


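# A usage sketch (argv values here are illustrative): init_args() is handy
# when embedding, since callers can extend the parser or parse a custom
# argv instead of sys.argv:
#
#     parser = init_args()
#     parser.add_argument("--my_extra_flag", type=str2bool, default=False)
#     args = parser.parse_args(["--use_gpu", "false", "--det_model_dir", "./det"])

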
def create_predictor(args, mode, logger):
    if mode == "det":
        model_dir = args.det_model_dir
    elif mode == "cls":
        model_dir = args.cls_model_dir
    elif mode == "rec":
        model_dir = args.rec_model_dir
    elif mode == "table":
        model_dir = args.table_model_dir
    elif mode == "ser":
        model_dir = args.ser_model_dir
    elif mode == "re":
        model_dir = args.re_model_dir
    elif mode == "sr":
        model_dir = args.sr_model_dir
    elif mode == "layout":
        model_dir = args.layout_model_dir
    else:
        model_dir = args.e2e_model_dir

    if model_dir is None:
        logger.info("could not find {} model file path {}".format(mode, model_dir))
        sys.exit(0)
    if args.use_onnx:
        import onnxruntime as ort

        model_file_path = model_dir
        if not os.path.exists(model_file_path):
            raise ValueError("could not find model file path {}".format(model_file_path))

        sess_options = args.onnx_sess_options or None

        if args.onnx_providers and len(args.onnx_providers) > 0:
            sess = ort.InferenceSession(
                model_file_path,
                providers=args.onnx_providers,
                sess_options=sess_options,
            )
        elif args.use_gpu:
            sess = ort.InferenceSession(
                model_file_path,
                providers=[
                    (
                        "CUDAExecutionProvider",
                        {"device_id": args.gpu_id, "cudnn_conv_algo_search": "DEFAULT"},
                    )
                ],
                sess_options=sess_options,
            )
        else:
            sess = ort.InferenceSession(
                model_file_path,
                providers=["CPUExecutionProvider"],
                sess_options=sess_options,
            )
        inputs = sess.get_inputs()
        return (
            sess,
            inputs[0] if len(inputs) == 1 else [vo.name for vo in inputs],
            None,
            None,
        )

    else:
        file_names = ["model", "inference"]
        for file_name in file_names:
            params_file_path = f"{model_dir}/{file_name}.pdiparams"
            if os.path.exists(params_file_path):
                break

        if not os.path.exists(params_file_path):
            raise ValueError(f"could not find {file_name}.pdiparams in {model_dir}")

        if not (
            os.path.exists(f"{model_dir}/{file_name}.pdmodel")
            or os.path.exists(f"{model_dir}/{file_name}.json")
        ):
            raise ValueError(
                f"neither {file_name}.json nor {file_name}.pdmodel was found in {model_dir}."
            )

        if os.path.exists(f"{model_dir}/{file_name}.json"):
            model_file_path = f"{model_dir}/{file_name}.json"
        else:
            model_file_path = f"{model_dir}/{file_name}.pdmodel"

        config = inference.Config(model_file_path, params_file_path)

        if hasattr(args, "precision"):
            if args.precision == "fp16" and args.use_tensorrt:
                precision = inference.PrecisionType.Half
            elif args.precision == "int8":
                precision = inference.PrecisionType.Int8
            else:
                precision = inference.PrecisionType.Float32
        else:
            precision = inference.PrecisionType.Float32

        if args.use_gpu:
            gpu_id = get_infer_gpuid()
            if gpu_id is None:
                logger.warning(
                    "GPU is not found in the current device by nvidia-smi. Please check your device, or ignore this message if running on Jetson."
                )
            config.enable_use_gpu(args.gpu_mem, args.gpu_id)
            if args.use_tensorrt:
                config.enable_tensorrt_engine(
                    workspace_size=1 << 30,
                    precision_mode=precision,
                    max_batch_size=args.max_batch_size,
                    min_subgraph_size=args.min_subgraph_size,  # skip the minimum trt subgraph
                    use_calib_mode=False,
                )

                # collect shape
                trt_shape_f = os.path.join(model_dir, f"{mode}_trt_dynamic_shape.txt")

                if not os.path.exists(trt_shape_f):
                    config.collect_shape_range_info(trt_shape_f)
                    logger.info(f"collect dynamic shape info into: {trt_shape_f}")
                try:
                    config.enable_tuned_tensorrt_dynamic_shape(trt_shape_f, True)
                except Exception as E:
                    logger.info(E)
                    logger.info("Please keep your paddlepaddle-gpu >= 2.3.0!")

        elif args.use_npu:
            config.enable_custom_device("npu")
        elif args.use_mlu:
            config.enable_custom_device("mlu")
        elif args.use_xpu:
            config.enable_xpu(10 * 1024 * 1024)
        elif args.use_gcu:  # for Enflame GCU(General Compute Unit)
            assert paddle.device.is_compiled_with_custom_device("gcu"), (
                "Args use_gcu cannot be set as True while your paddle "
                "is not compiled with gcu! \nPlease try: \n"
                "\t1. Install paddle-custom-gcu to run model on GCU. \n"
                "\t2. Set use_gcu as False in args to run model on CPU."
            )
            import paddle_custom_device.gcu.passes as gcu_passes

            gcu_passes.setUp()
            if args.precision == "fp16":
                config.enable_custom_device(
                    "gcu", 0, paddle.inference.PrecisionType.Half
                )
                gcu_passes.set_exp_enable_mixed_precision_ops(config)
            else:
                config.enable_custom_device("gcu")

            if paddle.framework.use_pir_api():
                config.enable_new_ir(True)
                config.enable_new_executor(True)
            else:
                pass_builder = config.pass_builder()
                gcu_passes.append_passes_for_legacy_ir(pass_builder, "PaddleOCR")
        else:
            config.disable_gpu()
            if args.enable_mkldnn:
                # cache 10 different shapes for mkldnn to avoid memory leak
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
                if args.precision == "fp16":
                    config.enable_mkldnn_bfloat16()
                if hasattr(args, "cpu_threads"):
                    config.set_cpu_math_library_num_threads(args.cpu_threads)
                else:
                    # default cpu threads as 10
                    config.set_cpu_math_library_num_threads(10)
        # enable memory optim
        config.enable_memory_optim()
        config.disable_glog_info()
        if not args.use_gcu:  # for Enflame GCU(General Compute Unit)
            config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
        config.delete_pass("matmul_transpose_reshape_fuse_pass")
        if mode == "rec" and args.rec_algorithm == "SRN":
            config.delete_pass("gpu_cpu_map_matmul_v2_to_matmul_pass")
        if mode == "re":
            config.delete_pass("simplify_with_basic_ops_pass")
        if mode == "table":
            config.delete_pass("fc_fuse_pass")  # not supported for table
        config.switch_use_feed_fetch_ops(False)
        config.switch_ir_optim(True)

        # create predictor
        predictor = inference.create_predictor(config)
        input_names = predictor.get_input_names()
        if mode in ["ser", "re"]:
            input_tensor = []
            for name in input_names:
                input_tensor.append(predictor.get_input_handle(name))
        else:
            for name in input_names:
                input_tensor = predictor.get_input_handle(name)
        output_tensors = get_output_tensors(args, mode, predictor)
        return predictor, input_tensor, output_tensors, config


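# A usage sketch for create_predictor (the model directory below is a
# placeholder for a real exported inference model, not a shipped path):
#
#     args = parse_args()
#     args.det_model_dir = "./inference/det_db"  # hypothetical export dir
#     predictor, input_tensor, output_tensors, config = create_predictor(
#         args, "det", get_logger()
#     )
#     # feed a preprocessed NCHW float32 batch, then run:
#     #   input_tensor.copy_from_cpu(batch); predictor.run()

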
def get_output_tensors(args, mode, predictor):
    output_names = predictor.get_output_names()
    output_tensors = []
    if mode == "rec" and args.rec_algorithm in ["CRNN", "SVTR_LCNet", "SVTR_HGNet"]:
        # these recognizers expose a dedicated softmax output when available
        output_name = "softmax_0.tmp_0"
        if output_name in output_names:
            return [predictor.get_output_handle(output_name)]
    for output_name in output_names:
        output_tensor = predictor.get_output_handle(output_name)
        output_tensors.append(output_tensor)
    return output_tensors


def get_infer_gpuid():
    """
    Get the GPU ID to be used for inference.

    Returns:
        int: The GPU ID to be used for inference.
    """
    logger = get_logger()
    if not paddle.device.is_compiled_with_rocm():
        gpu_id_str = os.environ.get("CUDA_VISIBLE_DEVICES", "0")
    else:
        gpu_id_str = os.environ.get("HIP_VISIBLE_DEVICES", "0")

    gpu_ids = gpu_id_str.split(",")
    logger.warning(
        "The first GPU is used for inference by default, GPU ID: {}".format(gpu_ids[0])
    )
    return int(gpu_ids[0])


def draw_e2e_res(dt_boxes, strs, img_path):
    src_im = cv2.imread(img_path)
    for box, text in zip(dt_boxes, strs):
        box = box.astype(np.int32).reshape((-1, 1, 2))
        cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
        cv2.putText(
            src_im,
            text,
            org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
            fontFace=cv2.FONT_HERSHEY_COMPLEX,
            fontScale=0.7,
            color=(0, 255, 0),
            thickness=1,
        )
    return src_im


def draw_text_det_res(dt_boxes, img):
    for box in dt_boxes:
        box = np.array(box).astype(np.int32).reshape(-1, 2)
        cv2.polylines(img, [box], True, color=(255, 255, 0), thickness=2)
    return img


def resize_img(img, input_size=600):
    """
    Resize img and limit the longest side of the image to input_size.
    """
    img = np.array(img)
    im_shape = img.shape
    im_size_max = np.max(im_shape[0:2])
    im_scale = float(input_size) / float(im_size_max)
    img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale)
    return img


def draw_ocr(
    image,
    boxes,
    txts=None,
    scores=None,
    drop_score=0.5,
    font_path="./doc/fonts/simfang.ttf",
):
    """
    Visualize the results of OCR detection and recognition.
    args:
        image(Image|array): RGB image
        boxes(list): boxes with shape (N, 4, 2)
        txts(list): the recognized texts
        scores(list): the corresponding scores of txts
        drop_score(float): only scores greater than drop_score will be visualized
        font_path: the path of the font used to draw text
    return(array):
        the visualized img
    """
    if scores is None:
        scores = [1] * len(boxes)
    box_num = len(boxes)
    for i in range(box_num):
        if scores is not None and (scores[i] < drop_score or math.isnan(scores[i])):
            continue
        box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
    if txts is not None:
        img = np.array(resize_img(image, input_size=600))
        txt_img = text_visual(
            txts,
            scores,
            img_h=img.shape[0],
            img_w=600,
            threshold=drop_score,
            font_path=font_path,
        )
        img = np.concatenate([np.array(img), np.array(txt_img)], axis=1)
        return img
    return image


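# A usage sketch for draw_ocr (image and font paths are illustrative):
#
#     image = Image.open("doc/imgs/example.jpg").convert("RGB")
#     vis = draw_ocr(image, boxes, txts, scores, drop_score=0.5,
#                    font_path="./doc/fonts/simfang.ttf")
#     cv2.imwrite("vis.jpg", vis[:, :, ::-1])  # RGB -> BGR for OpenCV

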
def draw_ocr_box_txt(
    image,
    boxes,
    txts=None,
    scores=None,
    drop_score=0.5,
    font_path="./doc/fonts/simfang.ttf",
):
    h, w = image.height, image.width
    img_left = image.copy()
    img_right = np.ones((h, w, 3), dtype=np.uint8) * 255
    random.seed(0)

    draw_left = ImageDraw.Draw(img_left)
    if txts is None or len(txts) != len(boxes):
        txts = [None] * len(boxes)
    for idx, (box, txt) in enumerate(zip(boxes, txts)):
        if scores is not None and scores[idx] < drop_score:
            continue
        color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        draw_left.polygon(box, fill=color)
        img_right_text = draw_box_txt_fine((w, h), box, txt, font_path)
        pts = np.array(box, np.int32).reshape((-1, 1, 2))
        cv2.polylines(img_right_text, [pts], True, color, 1)
        img_right = cv2.bitwise_and(img_right, img_right_text)
    img_left = Image.blend(image, img_left, 0.5)
    img_show = Image.new("RGB", (w * 2, h), (255, 255, 255))
    img_show.paste(img_left, (0, 0, w, h))
    img_show.paste(Image.fromarray(img_right), (w, 0, w * 2, h))
    return np.array(img_show)


def draw_box_txt_fine(img_size, box, txt, font_path="./doc/fonts/simfang.ttf"):
    box_height = int(
        math.sqrt((box[0][0] - box[3][0]) ** 2 + (box[0][1] - box[3][1]) ** 2)
    )
    box_width = int(
        math.sqrt((box[0][0] - box[1][0]) ** 2 + (box[0][1] - box[1][1]) ** 2)
    )

    if box_height > 2 * box_width and box_height > 30:
        # vertical text: render horizontally, then rotate into place
        img_text = Image.new("RGB", (box_height, box_width), (255, 255, 255))
        draw_text = ImageDraw.Draw(img_text)
        if txt:
            font = create_font(txt, (box_height, box_width), font_path)
            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)
        img_text = img_text.transpose(Image.ROTATE_270)
    else:
        img_text = Image.new("RGB", (box_width, box_height), (255, 255, 255))
        draw_text = ImageDraw.Draw(img_text)
        if txt:
            font = create_font(txt, (box_width, box_height), font_path)
            draw_text.text([0, 0], txt, fill=(0, 0, 0), font=font)

    pts1 = np.float32(
        [[0, 0], [box_width, 0], [box_width, box_height], [0, box_height]]
    )
    pts2 = np.array(box, dtype=np.float32)
    M = cv2.getPerspectiveTransform(pts1, pts2)

    img_text = np.array(img_text, dtype=np.uint8)
    img_right_text = cv2.warpPerspective(
        img_text,
        M,
        img_size,
        flags=cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=(255, 255, 255),
    )
    return img_right_text


def create_font(txt, sz, font_path="./doc/fonts/simfang.ttf"):
    font_size = int(sz[1] * 0.99)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    if int(PIL.__version__.split(".")[0]) < 10:
        length = font.getsize(txt)[0]
    else:
        length = font.getlength(txt)

    if length > sz[0]:
        # shrink to fit, but never drop below 1pt (truetype rejects size 0)
        font_size = max(int(font_size * sz[0] / length), 1)
        font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
    return font


def str_count(s):
    """
    Estimate the display width of a string: a Chinese character counts as
    one unit, while an English letter, digit, or space counts as half a
    unit (rounded up in aggregate).
    args:
        s(string): the input string
    return(int):
        the estimated width in Chinese-character units
    """
    import string

    count_zh = count_pu = 0
    s_len = len(s)
    en_dg_count = 0
    for c in s:
        if c in string.ascii_letters or c.isdigit() or c.isspace():
            en_dg_count += 1
        elif c.isalpha():
            count_zh += 1
        else:
            count_pu += 1
    return s_len - math.ceil(en_dg_count / 2)


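# Worked examples (each ASCII letter/digit/space counts as half a unit):
#   str_count("中文")   -> 2   (two Chinese characters)
#   str_count("OCR")    -> 1   (3 - ceil(3 / 2))
#   str_count("中文ab") -> 3   (4 - ceil(2 / 2))

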
def text_visual(
    texts, scores, img_h=400, img_w=600, threshold=0.0, font_path="./doc/simfang.ttf"
):
    """
    Create a new blank img and draw txt on it.
    args:
        texts(list): the texts to be drawn
        scores(list|None): the corresponding score of each txt
        img_h(int): the height of the blank img
        img_w(int): the width of the blank img
        threshold(float): only scores greater than threshold will be drawn
        font_path: the path of the font used to draw text
    return(array):
        the visualized img
    """
    if scores is not None:
        assert len(texts) == len(
            scores
        ), "The number of txts and corresponding scores must match"

    def create_blank_img():
        blank_img = np.ones(shape=[img_h, img_w], dtype=np.uint8) * 255
        blank_img[:, img_w - 1 :] = 0
        blank_img = Image.fromarray(blank_img).convert("RGB")
        draw_txt = ImageDraw.Draw(blank_img)
        return blank_img, draw_txt

    blank_img, draw_txt = create_blank_img()

    font_size = 20
    txt_color = (0, 0, 0)
    font = ImageFont.truetype(font_path, font_size, encoding="utf-8")

    gap = font_size + 5
    txt_img_list = []
    count, index = 1, 0
    for idx, txt in enumerate(texts):
        index += 1
        if scores[idx] < threshold or math.isnan(scores[idx]):
            index -= 1
            continue
        first_line = True
        while str_count(txt) >= img_w // font_size - 4:
            tmp = txt
            txt = tmp[: img_w // font_size - 4]
            if first_line:
                new_txt = str(index) + ": " + txt
                first_line = False
            else:
                new_txt = "    " + txt
            draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
            txt = tmp[img_w // font_size - 4 :]
            if count >= img_h // gap - 1:
                txt_img_list.append(np.array(blank_img))
                blank_img, draw_txt = create_blank_img()
                count = 0
            count += 1
        if first_line:
            new_txt = str(index) + ": " + txt + "   " + "%.3f" % (scores[idx])
        else:
            new_txt = "  " + txt + "  " + "%.3f" % (scores[idx])
        draw_txt.text((0, gap * count), new_txt, txt_color, font=font)
        # decide whether to start a new blank img
        if count >= img_h // gap - 1 and idx + 1 < len(texts):
            txt_img_list.append(np.array(blank_img))
            blank_img, draw_txt = create_blank_img()
            count = 0
        count += 1
    txt_img_list.append(np.array(blank_img))
    if len(txt_img_list) == 1:
        blank_img = np.array(txt_img_list[0])
    else:
        blank_img = np.concatenate(txt_img_list, axis=1)
    return np.array(blank_img)


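# A usage sketch for text_visual (the font path must point to an existing
# TrueType font; values are illustrative):
#
#     panel = text_visual(
#         ["hello", "world"], [0.98, 0.76],
#         img_h=400, img_w=600, threshold=0.5,
#         font_path="./doc/fonts/simfang.ttf",
#     )
#     cv2.imwrite("panel.jpg", panel[:, :, ::-1])

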
def base64_to_cv2(b64str):
    import base64

    data = base64.b64decode(b64str.encode("utf8"))
    data = np.frombuffer(data, np.uint8)
    data = cv2.imdecode(data, cv2.IMREAD_COLOR)
    return data


def draw_boxes(image, boxes, scores=None, drop_score=0.5):
    if scores is None:
        scores = [1] * len(boxes)
    for box, score in zip(boxes, scores):
        if score < drop_score:
            continue
        box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64)
        image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2)
    return image


def get_rotate_crop_image(img, points):
    """
    Crop the quadrilateral region described by ``points`` and rectify it
    with a perspective transform. An earlier axis-aligned implementation is
    kept below for reference:

        img_height, img_width = img.shape[0:2]
        left = int(np.min(points[:, 0]))
        right = int(np.max(points[:, 0]))
        top = int(np.min(points[:, 1]))
        bottom = int(np.max(points[:, 1]))
        img_crop = img[top:bottom, left:right, :].copy()
        points[:, 0] = points[:, 0] - left
        points[:, 1] = points[:, 1] - top
    """
    assert len(points) == 4, "shape of points must be 4*2"
    img_crop_width = int(
        max(
            np.linalg.norm(points[0] - points[1]), np.linalg.norm(points[2] - points[3])
        )
    )
    img_crop_height = int(
        max(
            np.linalg.norm(points[0] - points[3]), np.linalg.norm(points[1] - points[2])
        )
    )
    pts_std = np.float32(
        [
            [0, 0],
            [img_crop_width, 0],
            [img_crop_width, img_crop_height],
            [0, img_crop_height],
        ]
    )
    M = cv2.getPerspectiveTransform(points, pts_std)
    dst_img = cv2.warpPerspective(
        img,
        M,
        (img_crop_width, img_crop_height),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC,
    )
    dst_img_height, dst_img_width = dst_img.shape[0:2]
    if dst_img_height * 1.0 / dst_img_width >= 1.5:
        # a tall, narrow crop is probably vertical text; rotate it upright
        dst_img = np.rot90(dst_img)
    return dst_img


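# A minimal, self-contained sketch of get_rotate_crop_image on synthetic
# data (the helper below is illustrative and never called by this module):
def _demo_get_rotate_crop_image():
    img = np.zeros((100, 200, 3), dtype=np.uint8)
    # an axis-aligned 100x30 quad: TL, TR, BR, BL
    points = np.float32([[10, 10], [110, 10], [110, 40], [10, 40]])
    crop = get_rotate_crop_image(img, points)
    print(crop.shape)  # (30, 100, 3): width 100, height 30, not rotated

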
def get_minarea_rect_crop(img, points):
    bounding_box = cv2.minAreaRect(np.array(points).astype(np.int32))
    points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

    # order the four corners as top-left, top-right, bottom-right, bottom-left
    index_a, index_b, index_c, index_d = 0, 1, 2, 3
    if points[1][1] > points[0][1]:
        index_a = 0
        index_d = 1
    else:
        index_a = 1
        index_d = 0
    if points[3][1] > points[2][1]:
        index_b = 2
        index_c = 3
    else:
        index_b = 3
        index_c = 2

    box = [points[index_a], points[index_b], points[index_c], points[index_d]]
    crop_img = get_rotate_crop_image(img, np.array(box))
    return crop_img


def slice_generator(image, horizontal_stride, vertical_stride, maximum_slices=500):
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    image_h, image_w = image.shape[:2]
    vertical_num_slices = (image_h + vertical_stride - 1) // vertical_stride
    horizontal_num_slices = (image_w + horizontal_stride - 1) // horizontal_stride

    assert (
        vertical_num_slices > 0
    ), f"Invalid number ({vertical_num_slices}) of vertical slices"

    assert (
        horizontal_num_slices > 0
    ), f"Invalid number ({horizontal_num_slices}) of horizontal slices"

    if vertical_num_slices >= maximum_slices:
        recommended_vertical_stride = max(1, image_h // maximum_slices) + 1
        raise ValueError(
            f"Too computationally expensive with {vertical_num_slices} slices, "
            f"try a higher vertical stride (recommended minimum: {recommended_vertical_stride})"
        )

    if horizontal_num_slices >= maximum_slices:
        recommended_horizontal_stride = max(1, image_w // maximum_slices) + 1
        raise ValueError(
            f"Too computationally expensive with {horizontal_num_slices} slices, "
            f"try a higher horizontal stride (recommended minimum: {recommended_horizontal_stride})"
        )

    for v_slice_idx in range(vertical_num_slices):
        v_start = max(0, (v_slice_idx * vertical_stride))
        v_end = min(((v_slice_idx + 1) * vertical_stride), image_h)
        vertical_slice = image[v_start:v_end, :]
        for h_slice_idx in range(horizontal_num_slices):
            h_start = max(0, (h_slice_idx * horizontal_stride))
            h_end = min(((h_slice_idx + 1) * horizontal_stride), image_w)
            horizontal_slice = vertical_slice[:, h_start:h_end]

            yield (horizontal_slice, v_start, h_start)


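# A minimal, self-contained sketch of slice_generator (illustrative helper,
# never called by this module): a 100x100 image with stride 40 yields a
# 3x3 grid of non-overlapping tiles, the last row/column being smaller.
def _demo_slice_generator():
    img = np.zeros((100, 100, 3), dtype=np.uint8)
    for tile, v_start, h_start in slice_generator(img, 40, 40):
        print(tile.shape, v_start, h_start)  # (40, 40, 3) 0 0 ... (20, 20, 3) 80 80

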
def calculate_box_extents(box):
    # assumes an axis-aligned quad ordered TL, TR, BR, BL
    min_x = box[0][0]
    max_x = box[1][0]
    min_y = box[0][1]
    max_y = box[2][1]
    return min_x, max_x, min_y, max_y


def merge_boxes(box1, box2, x_threshold, y_threshold):
    min_x1, max_x1, min_y1, max_y1 = calculate_box_extents(box1)
    min_x2, max_x2, min_y2, max_y2 = calculate_box_extents(box2)

    # merge only when the boxes share a text line (similar y extents) and
    # are horizontally adjacent within x_threshold
    if (
        abs(min_y1 - min_y2) <= y_threshold
        and abs(max_y1 - max_y2) <= y_threshold
        and abs(max_x1 - min_x2) <= x_threshold
    ):
        new_xmin = min(min_x1, min_x2)
        new_xmax = max(max_x1, max_x2)
        new_ymin = min(min_y1, min_y2)
        new_ymax = max(max_y1, max_y2)
        return [
            [new_xmin, new_ymin],
            [new_xmax, new_ymin],
            [new_xmax, new_ymax],
            [new_xmin, new_ymax],
        ]
    else:
        return None


def merge_fragmented(boxes, x_threshold=10, y_threshold=10):
    merged_boxes = []
    visited = set()

    for i, box1 in enumerate(boxes):
        if i in visited:
            continue

        merged_box = [point[:] for point in box1]

        for j, box2 in enumerate(boxes[i + 1 :], start=i + 1):
            if j not in visited:
                merged_result = merge_boxes(
                    merged_box, box2, x_threshold=x_threshold, y_threshold=y_threshold
                )
                if merged_result:
                    merged_box = merged_result
                    visited.add(j)

        merged_boxes.append(merged_box)

    # repeat until a pass produces no further merges (fixed point)
    if len(merged_boxes) == len(boxes):
        return np.array(merged_boxes)
    else:
        return merge_fragmented(merged_boxes, x_threshold, y_threshold)


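# A minimal, self-contained sketch of merge_fragmented (illustrative helper,
# never called by this module): two boxes on the same line, 2px apart, are
# merged into one.
def _demo_merge_fragmented():
    boxes = [
        [[0, 0], [10, 0], [10, 5], [0, 5]],
        [[12, 0], [22, 0], [22, 5], [12, 5]],
    ]
    merged = merge_fragmented(boxes, x_threshold=5, y_threshold=5)
    print(merged)  # one merged box: [[0, 0], [22, 0], [22, 5], [0, 5]]

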
def check_gpu(use_gpu):
    # fall back to CPU when paddle lacks CUDA support or no GPU is visible
    if use_gpu and (
        not paddle.is_compiled_with_cuda() or paddle.device.get_device() == "cpu"
    ):
        use_gpu = False
    return use_gpu


if __name__ == "__main__":
    pass