mirror of
				https://github.com/PaddlePaddle/PaddleOCR.git
				synced 2025-11-04 11:49:14 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			416 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			416 lines
		
	
	
		
			17 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
						|
#
 | 
						|
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
						|
# you may not use this file except in compliance with the License.
 | 
						|
# You may obtain a copy of the License at
 | 
						|
#
 | 
						|
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
# Unless required by applicable law or agreed to in writing, software
 | 
						|
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
						|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
						|
# See the License for the specific language governing permissions and
 | 
						|
# limitations under the License.
 | 
						|
import os
 | 
						|
import sys
 | 
						|
from PIL import Image
 | 
						|
__dir__ = os.path.dirname(os.path.abspath(__file__))
 | 
						|
sys.path.append(__dir__)
 | 
						|
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
 | 
						|
 | 
						|
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
 | 
						|
 | 
						|
import cv2
 | 
						|
import numpy as np
 | 
						|
import math
 | 
						|
import time
 | 
						|
import traceback
 | 
						|
import paddle
 | 
						|
 | 
						|
import tools.infer.utility as utility
 | 
						|
from ppocr.postprocess import build_post_process
 | 
						|
from ppocr.utils.logging import get_logger
 | 
						|
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
 | 
						|
 | 
						|
logger = get_logger()
 | 
						|
 | 
						|
 | 
						|
class TextRecognizer(object):
    """Recognize text in cropped text-line images with an inference predictor.

    The instance is configured from a parsed-argument namespace, builds the
    post-process decoder that matches ``args.rec_algorithm`` and creates the
    predictor through ``utility.create_predictor``.  Calling the instance
    with a list of images returns the decoded results in input order together
    with the elapsed wall-clock time.
    """

    # Decoder name per recognition algorithm; algorithms that are not listed
    # here are decoded with 'CTCLabelDecode'.
    _DECODER_BY_ALGORITHM = {
        "SRN": 'SRNLabelDecode',
        "RARE": 'AttnLabelDecode',
        'NRTR': 'NRTRLabelDecode',
        "SAR": 'SARLabelDecode',
    }

    def __init__(self, args):
        """Build the decoder, the predictor and (optionally) the benchmark logger.

        Args:
            args: namespace providing ``rec_image_shape`` (comma-separated
                ints), ``rec_batch_num``, ``rec_algorithm``,
                ``rec_char_dict_path``, ``use_space_char``, ``benchmark`` and
                ``use_onnx``; ``precision`` and ``use_gpu`` are additionally
                read when ``benchmark`` is truthy.
        """
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        postprocess_params = {
            'name': self._DECODER_BY_ALGORITHM.get(self.rec_algorithm,
                                                   'CTCLabelDecode'),
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char
        }
        self.postprocess_op = build_post_process(postprocess_params)
        self.predictor, self.input_tensor, self.output_tensors, self.config = \
            utility.create_predictor(args, 'rec', logger)
        self.benchmark = args.benchmark
        self.use_onnx = args.use_onnx
        if args.benchmark:
            # auto_log is only needed (and only imported) in benchmark mode.
            import auto_log
            pid = os.getpid()
            gpu_id = utility.get_infer_gpuid()
            self.autolog = auto_log.AutoLogger(
                model_name="rec",
                model_precision=args.precision,
                batch_size=args.rec_batch_num,
                data_shape="dynamic",
                save_path=None,  #args.save_log_path,
                inference_config=self.config,
                pids=pid,
                process_name=None,
                gpu_ids=gpu_id if args.use_gpu else None,
                time_keys=[
                    'preprocess_time', 'inference_time', 'postprocess_time'
                ],
                warmup=0,
                logger=logger)

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize ``img`` to the model height, normalize it and right-pad it.

        For the 'NRTR' algorithm the image is converted to grayscale, resized
        to 100x32 and scaled with ``x / 128 - 1``; the result has a single
        leading channel axis.  For every other algorithm the image is resized
        to height ``imgH`` keeping its aspect ratio (capped at the batch
        width), scaled to [-1, 1] and zero-padded on the right.

        Args:
            img: image array indexed as (height, width, channel).
            max_wh_ratio: largest width/height ratio in the current batch; it
                determines the padded width shared by the whole batch.

        Returns:
            A float32 array laid out as (channel, height, width).

        Raises:
            AssertionError: if the channel count of ``img`` differs from the
                configured ``rec_image_shape``.
        """
        imgC, imgH, imgW = self.rec_image_shape
        if self.rec_algorithm == 'NRTR':
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            image_pil = Image.fromarray(np.uint8(img))
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same
            # filter under its current name.
            img = image_pil.resize([100, 32], Image.LANCZOS)
            img = np.array(img)
            norm_img = np.expand_dims(img, -1)
            norm_img = norm_img.transpose((2, 0, 1))
            return norm_img.astype(np.float32) / 128. - 1.

        assert imgC == img.shape[2]
        # Use the configured height (not a hard-coded 32) so that the widest
        # image of the batch is not squashed when imgH != 32.
        imgW = int(imgH * max_wh_ratio)
        if self.use_onnx:
            w = self.input_tensor.shape[3:][0]
            # A symbolic (dynamic) ONNX dimension is reported as a string and
            # cannot be compared with an int; only a fixed width overrides.
            if w is not None and not isinstance(w, str) and w > 0:
                imgW = w

        h, w = img.shape[:2]
        ratio = w / float(h)
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        # HWC -> CHW, then map [0, 255] to [-1, 1].
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def resize_norm_img_srn(self, img, image_shape):
        """Resize ``img`` for SRN and paste it, as grayscale, on a black canvas.

        The target width is ``imgH``, ``2 * imgH`` or ``3 * imgH`` depending
        on the aspect ratio of the input, and ``imgW`` for anything wider.

        Args:
            img: BGR image array indexed as (height, width, channel).
            image_shape: ``(imgC, imgH, imgW)`` of the model input.

        Returns:
            A float32 array of shape ``(1, imgH, imgW)``.
        """
        imgC, imgH, imgW = image_shape

        img_black = np.zeros((imgH, imgW))
        im_hei = img.shape[0]
        im_wid = img.shape[1]

        if im_wid <= im_hei * 1:
            img_new = cv2.resize(img, (imgH * 1, imgH))
        elif im_wid <= im_hei * 2:
            img_new = cv2.resize(img, (imgH * 2, imgH))
        elif im_wid <= im_hei * 3:
            img_new = cv2.resize(img, (imgH * 3, imgH))
        else:
            img_new = cv2.resize(img, (imgW, imgH))

        img_np = np.asarray(img_new)
        img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
        img_black[:, 0:img_np.shape[1]] = img_np
        img_black = img_black[:, :, np.newaxis]

        row, col, _ = img_black.shape
        # The canvas has exactly one channel, so moving it to the front is a
        # plain reshape.
        return np.reshape(img_black, (1, row, col)).astype(np.float32)

    def srn_other_inputs(self, image_shape, num_heads, max_text_length):
        """Build the position and attention-bias inputs fed alongside the image.

        Args:
            image_shape: ``(imgC, imgH, imgW)`` of the model input.
            num_heads: number of copies of each bias along axis 1.
            max_text_length: length of the word-position sequence and side of
                the square bias matrices.

        Returns:
            ``[encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2]``.  The two position arrays are int64 with a
            leading axis of size 1; the biases have shape
            ``(1, num_heads, max_text_length, max_text_length)`` and hold
            ``-1e9`` strictly above (bias1) / strictly below (bias2) the
            diagonal.
        """
        imgC, imgH, imgW = image_shape
        feature_dim = int((imgH / 8) * (imgW / 8))

        encoder_word_pos = np.array(range(0, feature_dim)).reshape(
            (feature_dim, 1)).astype('int64')
        gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
            (max_text_length, 1)).astype('int64')

        gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
        gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias1 = np.tile(
            gsrm_slf_attn_bias1,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]

        gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
            [-1, 1, max_text_length, max_text_length])
        gsrm_slf_attn_bias2 = np.tile(
            gsrm_slf_attn_bias2,
            [1, num_heads, 1, 1]).astype('float32') * [-1e9]

        encoder_word_pos = encoder_word_pos[np.newaxis, :]
        gsrm_word_pos = gsrm_word_pos[np.newaxis, :]

        return [
            encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2
        ]

    def process_image_srn(self, img, image_shape, num_heads, max_text_length):
        """Prepare every SRN input for a single image.

        Args:
            img: BGR image array indexed as (height, width, channel).
            image_shape: ``(imgC, imgH, imgW)`` of the model input.
            num_heads: forwarded to ``srn_other_inputs``.
            max_text_length: forwarded to ``srn_other_inputs``.

        Returns:
            ``(norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2)``; the image and both biases are float32, the
            positions int64, and each carries a leading axis of size 1.
        """
        norm_img = self.resize_norm_img_srn(img, image_shape)
        norm_img = norm_img[np.newaxis, :]

        [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
            self.srn_other_inputs(image_shape, num_heads, max_text_length)

        gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
        gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
        encoder_word_pos = encoder_word_pos.astype(np.int64)
        gsrm_word_pos = gsrm_word_pos.astype(np.int64)

        return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
                gsrm_slf_attn_bias2)

    def resize_norm_img_sar(self, img, image_shape,
                            width_downsample_ratio=0.25):
        """Resize, normalize and pad ``img`` for SAR.

        Args:
            img: image array whose first two axes are (height, width).
            image_shape: ``(imgC, imgH, imgW_min, imgW_max)``.
            width_downsample_ratio: the resized width is rounded to a multiple
                of ``int(1 / width_downsample_ratio)``.

        Returns:
            ``(padding_im, resize_shape, pad_shape, valid_ratio)`` where
            ``padding_im`` is a float32 ``(imgC, imgH, imgW_max)`` array
            filled with -1.0 to the right of the image, ``resize_shape`` /
            ``pad_shape`` are the shapes before / after padding and
            ``valid_ratio`` is ``min(1.0, resize_w / imgW_max)``.
        """
        imgC, imgH, imgW_min, imgW_max = image_shape
        h = img.shape[0]
        w = img.shape[1]
        valid_ratio = 1.0
        # make sure new_width is an integral multiple of width_divisor.
        width_divisor = int(1 / width_downsample_ratio)
        # resize
        ratio = w / float(h)
        resize_w = math.ceil(imgH * ratio)
        if resize_w % width_divisor != 0:
            resize_w = round(resize_w / width_divisor) * width_divisor
        if imgW_min is not None:
            resize_w = max(imgW_min, resize_w)
        if imgW_max is not None:
            # Computed before clamping so it reflects how much of the padded
            # width the image would occupy.
            valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
            resize_w = min(imgW_max, resize_w)
        resized_image = cv2.resize(img, (resize_w, imgH))
        resized_image = resized_image.astype('float32')
        # norm
        if image_shape[0] == 1:
            resized_image = resized_image / 255
            resized_image = resized_image[np.newaxis, :]
        else:
            resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        resize_shape = resized_image.shape
        padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
        padding_im[:, :, 0:resize_w] = resized_image
        pad_shape = padding_im.shape

        return padding_im, resize_shape, pad_shape, valid_ratio

    def _run_onnx(self, norm_img_batch):
        """Run the ONNX predictor; only the image batch is fed to it."""
        input_dict = {self.input_tensor.name: norm_img_batch}
        return self.predictor.run(self.output_tensors, input_dict)

    def _run_paddle(self):
        """Run the predictor on already-copied inputs and fetch every output."""
        self.predictor.run()
        outputs = [
            output_tensor.copy_to_cpu() for output_tensor in self.output_tensors
        ]
        if self.benchmark:
            self.autolog.times.stamp()
        return outputs

    def _infer_multi_input(self, norm_img_batch, inputs):
        """Run a model taking several inputs (SRN / SAR) and return its outputs."""
        if self.use_onnx:
            return self._run_onnx(norm_img_batch)
        for i, input_name in enumerate(self.predictor.get_input_names()):
            input_tensor = self.predictor.get_input_handle(input_name)
            input_tensor.copy_from_cpu(inputs[i])
        return self._run_paddle()

    def __call__(self, img_list):
        """Recognize the text of every image in ``img_list``.

        Images are processed in batches of ``rec_batch_num`` in order of
        increasing width/height ratio, and the results are written back at
        the original positions.

        Args:
            img_list: list of image arrays indexed as (height, width, ...).

        Returns:
            ``(rec_res, elapsed)`` where ``rec_res[i]`` is the post-process
            result for ``img_list[i]`` (``['', 0.0]`` if none was produced)
            and ``elapsed`` is the wall-clock time in seconds.
        """
        img_num = len(img_list)
        # Calculate the aspect ratio of all text bars
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        # Sorting can speed up the recognition process
        indices = np.argsort(np.array(width_list))
        # One independent placeholder per image (no aliasing between slots).
        rec_res = [['', 0.0] for _ in range(img_num)]
        batch_num = self.rec_batch_num
        st = time.time()
        if self.benchmark:
            self.autolog.times.start()
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            batch_imgs = [
                img_list[indices[ino]] for ino in range(beg_img_no, end_img_no)
            ]
            max_wh_ratio = 0
            for img in batch_imgs:
                h, w = img.shape[0:2]
                max_wh_ratio = max(max_wh_ratio, w * 1.0 / h)

            norm_img_batch = []
            # Auxiliary inputs are accumulated per batch so that they stay
            # aligned with norm_img_batch; creating them per image would keep
            # only the last image's entries.
            valid_ratios = []
            encoder_word_pos_list = []
            gsrm_word_pos_list = []
            gsrm_slf_attn_bias1_list = []
            gsrm_slf_attn_bias2_list = []
            for img in batch_imgs:
                if self.rec_algorithm == "SAR":
                    norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
                        img, self.rec_image_shape)
                    norm_img_batch.append(norm_img[np.newaxis, :])
                    valid_ratios.append(np.expand_dims(valid_ratio, axis=0))
                elif self.rec_algorithm == "SRN":
                    srn_inputs = self.process_image_srn(
                        img, self.rec_image_shape, 8, 25)
                    norm_img_batch.append(srn_inputs[0])
                    encoder_word_pos_list.append(srn_inputs[1])
                    gsrm_word_pos_list.append(srn_inputs[2])
                    gsrm_slf_attn_bias1_list.append(srn_inputs[3])
                    gsrm_slf_attn_bias2_list.append(srn_inputs[4])
                else:
                    norm_img = self.resize_norm_img(img, max_wh_ratio)
                    norm_img_batch.append(norm_img[np.newaxis, :])
            norm_img_batch = np.concatenate(norm_img_batch)
            norm_img_batch = norm_img_batch.copy()
            if self.benchmark:
                self.autolog.times.stamp()

            if self.rec_algorithm == "SRN":
                inputs = [
                    norm_img_batch,
                    np.concatenate(encoder_word_pos_list),
                    np.concatenate(gsrm_word_pos_list),
                    np.concatenate(gsrm_slf_attn_bias1_list),
                    np.concatenate(gsrm_slf_attn_bias2_list),
                ]
                outputs = self._infer_multi_input(norm_img_batch, inputs)
                preds = {"predict": outputs[2]}
            elif self.rec_algorithm == "SAR":
                inputs = [
                    norm_img_batch,
                    np.concatenate(valid_ratios),
                ]
                outputs = self._infer_multi_input(norm_img_batch, inputs)
                preds = outputs[0]
            elif self.use_onnx:
                outputs = self._run_onnx(norm_img_batch)
                preds = outputs[0]
            else:
                self.input_tensor.copy_from_cpu(norm_img_batch)
                outputs = self._run_paddle()
                # A single output is passed to the decoder unwrapped.
                preds = outputs if len(outputs) != 1 else outputs[0]

            rec_result = self.postprocess_op(preds)
            for rno, result in enumerate(rec_result):
                rec_res[indices[beg_img_no + rno]] = result
            if self.benchmark:
                self.autolog.times.end(stamp=True)
        return rec_res, time.time() - st
 | 
						|
 | 
						|
 | 
						|
def main(args):
    """Recognize the text of every readable image under ``args.image_dir``.

    Builds a ``TextRecognizer`` from ``args``, optionally warms it up, loads
    the images, runs recognition once over all of them and logs one
    prediction per image.  If recognition raises, the traceback is logged and
    the process exits with status 1.

    Args:
        args: parsed arguments; ``image_dir``, ``warmup``, ``rec_batch_num``
            and ``benchmark`` are read here, the rest by ``TextRecognizer``.
    """
    image_file_list = get_image_file_list(args.image_dir)
    text_recognizer = TextRecognizer(args)
    valid_image_file_list = []
    img_list = []

    # warmup 2 times
    if args.warmup:
        img = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8)
        for _ in range(2):
            text_recognizer([img] * int(args.rec_batch_num))

    for image_file in image_file_list:
        img, flag = check_and_read_gif(image_file)
        if not flag:
            img = cv2.imread(image_file)
        if img is None:
            logger.info("error in loading image:{}".format(image_file))
            continue
        valid_image_file_list.append(image_file)
        img_list.append(img)
    try:
        rec_res, _ = text_recognizer(img_list)

    except Exception as E:
        logger.info(traceback.format_exc())
        logger.info(E)
        # sys.exit works without the `site` module (unlike the interactive
        # `exit` helper) and a non-zero status reports the failure to callers.
        sys.exit(1)
    for image_file, result in zip(valid_image_file_list, rec_res):
        logger.info("Predicts of {}:{}".format(image_file, result))
    if args.benchmark:
        text_recognizer.autolog.report()


if __name__ == "__main__":
    main(utility.parse_args())
 |