combine args in paddleocr and ppocr/infer/utility

This commit is contained in:
WenmuZhou 2021-05-26 17:34:47 +08:00
parent 5d24736a62
commit eaf38b9b12
5 changed files with 309 additions and 180 deletions

View File

@ -59,7 +59,7 @@ im_show.save('result.jpg')
from paddleocr import PaddleOCR, draw_ocr from paddleocr import PaddleOCR, draw_ocr
ocr = PaddleOCR() # need to run only once to download and load model into memory ocr = PaddleOCR() # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs/11.jpg' img_path = 'PaddleOCR/doc/imgs/11.jpg'
result = ocr.ocr(img_path) result = ocr.ocr(img_path,cls=False)
for line in result: for line in result:
print(line) print(line)

View File

@ -59,7 +59,7 @@ Visualization of results
from paddleocr import PaddleOCR,draw_ocr from paddleocr import PaddleOCR,draw_ocr
ocr = PaddleOCR(lang='en') # need to run only once to download and load model into memory ocr = PaddleOCR(lang='en') # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs_en/img_12.jpg' img_path = 'PaddleOCR/doc/imgs_en/img_12.jpg'
result = ocr.ocr(img_path) result = ocr.ocr(img_path, cls=False)
for line in result: for line in result:
print(line) print(line)

View File

@ -30,7 +30,7 @@ from ppocr.utils.logging import get_logger
logger = get_logger() logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from tools.infer.utility import draw_ocr from tools.infer.utility import draw_ocr, inference_args_list, str2bool, parse_args
__all__ = ['PaddleOCR'] __all__ = ['PaddleOCR']
@ -167,106 +167,36 @@ def maybe_download(model_storage_directory, url):
os.remove(tmp_path) os.remove(tmp_path)
def parse_args(mMain=True, add_help=True): def parse_args_whl(mMain=True):
import argparse import argparse
extend_args_list = [
def str2bool(v): {
return v.lower() in ("true", "t", "1") 'name': 'lang',
'type': str,
'default': 'ch'
},
{
'name': 'det',
'type': str2bool,
'default': True
},
{
'name': 'rec',
'type': str2bool,
'default': True
},
]
for item in inference_args_list:
if item['name'] == 'rec_char_dict_path':
item['default'] = None
inference_args_list.extend(extend_args_list)
if mMain: if mMain:
parser = argparse.ArgumentParser(add_help=add_help) return parse_args()
# params for prediction engine
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000)
# params for text detector
parser.add_argument("--image_dir", type=str)
parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_model_dir", type=str, default=None)
parser.add_argument("--det_limit_side_len", type=float, default=960)
parser.add_argument("--det_limit_type", type=str, default='max')
# DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
parser.add_argument("--use_dilation", type=bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast")
# EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
# params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str, default=None)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument("--rec_char_dict_path", type=str, default=None)
parser.add_argument("--use_space_char", type=bool, default=True)
parser.add_argument("--drop_score", type=float, default=0.5)
# params for text classifier
parser.add_argument("--cls_model_dir", type=str, default=None)
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
parser.add_argument("--label_list", type=list, default=['0', '180'])
parser.add_argument("--cls_batch_num", type=int, default=6)
parser.add_argument("--cls_thresh", type=float, default=0.9)
parser.add_argument("--enable_mkldnn", type=bool, default=False)
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--lang", type=str, default='ch')
parser.add_argument("--det", type=str2bool, default=True)
parser.add_argument("--rec", type=str2bool, default=True)
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
return parser.parse_args()
else: else:
return argparse.Namespace( inference_args_dict = {}
use_gpu=True, for item in inference_args_list:
ir_optim=True, inference_args_dict[item['name']] = item['default']
use_tensorrt=False, return argparse.Namespace(**inference_args_dict)
gpu_mem=8000,
image_dir='',
det_algorithm='DB',
det_model_dir=None,
det_limit_side_len=960,
det_limit_type='max',
det_db_thresh=0.3,
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
use_dilation=False,
det_db_score_mode="fast",
det_east_score_thresh=0.8,
det_east_cover_thresh=0.1,
det_east_nms_thresh=0.2,
rec_algorithm='CRNN',
rec_model_dir=None,
rec_image_shape="3, 32, 320",
rec_char_type='ch',
rec_batch_num=6,
max_text_length=25,
rec_char_dict_path=None,
use_space_char=True,
drop_score=0.5,
cls_model_dir=None,
cls_image_shape="3, 48, 192",
label_list=['0', '180'],
cls_batch_num=6,
cls_thresh=0.9,
enable_mkldnn=False,
use_zero_copy_run=False,
use_pdserving=False,
lang='ch',
det=True,
rec=True,
use_angle_cls=False)
class PaddleOCR(predict_system.TextSystem): class PaddleOCR(predict_system.TextSystem):
@ -276,7 +206,7 @@ class PaddleOCR(predict_system.TextSystem):
args: args:
**kwargs: other params show in paddleocr --help **kwargs: other params show in paddleocr --help
""" """
postprocess_params = parse_args(mMain=False, add_help=False) postprocess_params = parse_args_whl(mMain=False)
postprocess_params.__dict__.update(**kwargs) postprocess_params.__dict__.update(**kwargs)
self.use_angle_cls = postprocess_params.use_angle_cls self.use_angle_cls = postprocess_params.use_angle_cls
lang = postprocess_params.lang lang = postprocess_params.lang
@ -346,7 +276,7 @@ class PaddleOCR(predict_system.TextSystem):
# init det_model and rec_model # init det_model and rec_model
super().__init__(postprocess_params) super().__init__(postprocess_params)
def ocr(self, img, det=True, rec=True, cls=False): def ocr(self, img, det=True, rec=True, cls=True):
""" """
ocr with paddleocr ocr with paddleocr
args args
@ -358,9 +288,7 @@ class PaddleOCR(predict_system.TextSystem):
if isinstance(img, list) and det == True: if isinstance(img, list) and det == True:
logger.error('When input a list of images, det must be false') logger.error('When input a list of images, det must be false')
exit(0) exit(0)
if cls == False: if cls == True and self.use_angle_cls == False:
self.use_angle_cls = False
elif cls == True and self.use_angle_cls == False:
logger.warning( logger.warning(
'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process' 'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process'
) )
@ -382,7 +310,7 @@ class PaddleOCR(predict_system.TextSystem):
if isinstance(img, np.ndarray) and len(img.shape) == 2: if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
if det and rec: if det and rec:
dt_boxes, rec_res = self.__call__(img) dt_boxes, rec_res = self.__call__(img, cls)
return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
elif det and not rec: elif det and not rec:
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
@ -392,7 +320,7 @@ class PaddleOCR(predict_system.TextSystem):
else: else:
if not isinstance(img, list): if not isinstance(img, list):
img = [img] img = [img]
if self.use_angle_cls: if self.use_angle_cls and cls:
img, cls_res, elapse = self.text_classifier(img) img, cls_res, elapse = self.text_classifier(img)
if not rec: if not rec:
return cls_res return cls_res
@ -402,7 +330,7 @@ class PaddleOCR(predict_system.TextSystem):
def main(): def main():
# for cmd # for cmd
args = parse_args(mMain=True) args = parse_args_whl(mMain=True)
image_dir = args.image_dir image_dir = args.image_dir
if image_dir.startswith('http'): if image_dir.startswith('http'):
download_with_progressbar(image_dir, 'tmp.jpg') download_with_progressbar(image_dir, 'tmp.jpg')

View File

@ -85,7 +85,7 @@ class TextSystem(object):
cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno]) cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno])
logger.info(bno, rec_res[bno]) logger.info(bno, rec_res[bno])
def __call__(self, img): def __call__(self, img, cls=True):
ori_im = img.copy() ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
logger.info("dt_boxes num : {}, elapse : {}".format( logger.info("dt_boxes num : {}, elapse : {}".format(
@ -100,7 +100,7 @@ class TextSystem(object):
tmp_box = copy.deepcopy(dt_boxes[bno]) tmp_box = copy.deepcopy(dt_boxes[bno])
img_crop = self.get_rotate_crop_image(ori_im, tmp_box) img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
img_crop_list.append(img_crop) img_crop_list.append(img_crop)
if self.use_angle_cls: if self.use_angle_cls and cls:
img_crop_list, angle_list, elapse = self.text_classifier( img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list) img_crop_list)
logger.info("cls num : {}, elapse : {}".format( logger.info("cls num : {}, elapse : {}".format(

View File

@ -23,87 +23,288 @@ import math
from paddle import inference from paddle import inference
def parse_args(): def str2bool(v):
def str2bool(v): return v.lower() in ("true", "t", "1")
return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser()
inference_args_list = [
# params for prediction engine # params for prediction engine
parser.add_argument("--use_gpu", type=str2bool, default=True) {
parser.add_argument("--ir_optim", type=str2bool, default=True) 'name': 'use_gpu',
parser.add_argument("--use_tensorrt", type=str2bool, default=False) 'type': str2bool,
parser.add_argument("--use_fp16", type=str2bool, default=False) 'default': True
parser.add_argument("--gpu_mem", type=int, default=500) },
{
'name': 'ir_optim',
'type': str2bool,
'default': True
},
{
'name': 'use_tensorrt',
'type': str2bool,
'default': False
},
{
'name': 'use_fp16',
'type': str2bool,
'default': False
},
{
'name': 'enable_mkldnn',
'type': str2bool,
'default': False
},
{
'name': 'use_pdserving',
'type': str2bool,
'default': False
},
{
'name': 'use_mp',
'type': str2bool,
'default': False
},
{
'name': 'total_process_num',
'type': int,
'default': 1
},
{
'name': 'process_id',
'type': int,
'default': 0
},
{
'name': 'gpu_mem',
'type': int,
'default': 500
},
# params for text detector # params for text detector
parser.add_argument("--image_dir", type=str) {
parser.add_argument("--det_algorithm", type=str, default='DB') 'name': 'image_dir',
parser.add_argument("--det_model_dir", type=str) 'type': str,
parser.add_argument("--det_limit_side_len", type=float, default=960) 'default': None
parser.add_argument("--det_limit_type", type=str, default='max') },
{
'name': 'det_algorithm',
'type': str,
'default': 'DB'
},
{
'name': 'det_model_dir',
'type': str,
'default': None
},
{
'name': 'det_limit_side_len',
'type': float,
'default': 960
},
{
'name': 'det_limit_type',
'type': str,
'default': 'max'
},
# DB parmas # DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3) {
parser.add_argument("--det_db_box_thresh", type=float, default=0.5) 'name': 'det_db_thresh',
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) 'type': float,
parser.add_argument("--max_batch_size", type=int, default=10) 'default': 0.3
parser.add_argument("--use_dilation", type=bool, default=False) },
parser.add_argument("--det_db_score_mode", type=str, default="fast") {
'name': 'det_db_box_thresh',
'type': float,
'default': 0.5
},
{
'name': 'det_db_unclip_ratio',
'type': float,
'default': 1.6
},
{
'name': 'max_batch_size',
'type': int,
'default': 10
},
{
'name': 'use_dilation',
'type': str2bool,
'default': False
},
{
'name': 'det_db_score_mode',
'type': str,
'default': 'fast'
},
# EAST parmas # EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8) {
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) 'name': 'det_east_score_thresh',
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) 'type': float,
'default': 0.8
},
{
'name': 'det_east_cover_thresh',
'type': float,
'default': 0.1
},
{
'name': 'det_east_nms_thresh',
'type': float,
'default': 0.2
},
# SAST parmas # SAST parmas
parser.add_argument("--det_sast_score_thresh", type=float, default=0.5) {
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) 'name': 'det_sast_score_thresh',
parser.add_argument("--det_sast_polygon", type=bool, default=False) 'type': float,
'default': 0.5
},
{
'name': 'det_sast_nms_thresh',
'type': float,
'default': 0.2
},
{
'name': 'det_sast_polygon',
'type': str2bool,
'default': False
},
# params for text recognizer # params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN') {
parser.add_argument("--rec_model_dir", type=str) 'name': 'rec_algorithm',
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") 'type': str,
parser.add_argument("--rec_char_type", type=str, default='ch') 'default': 'CRNN'
parser.add_argument("--rec_batch_num", type=int, default=6) },
parser.add_argument("--max_text_length", type=int, default=25) {
parser.add_argument( 'name': 'rec_model_dir',
"--rec_char_dict_path", 'type': str,
type=str, 'default': None
default="./ppocr/utils/ppocr_keys_v1.txt") },
parser.add_argument("--use_space_char", type=str2bool, default=True) {
parser.add_argument( 'name': 'rec_image_shape',
"--vis_font_path", type=str, default="./doc/fonts/simfang.ttf") 'type': str,
parser.add_argument("--drop_score", type=float, default=0.5) 'default': '3, 32, 320'
},
{
'name': 'rec_char_type',
'type': str,
'default': "ch"
},
{
'name': 'rec_batch_num',
'type': int,
'default': 6
},
{
'name': 'max_text_length',
'type': int,
'default': 25
},
{
'name': 'rec_char_dict_path',
'type': str,
'default': './ppocr/utils/ppocr_keys_v1.txt'
},
{
'name': 'use_space_char',
'type': str2bool,
'default': True
},
{
'name': 'vis_font_path',
'type': str,
'default': './doc/fonts/simfang.ttf'
},
{
'name': 'drop_score',
'type': float,
'default': 0.5
},
# params for e2e # params for e2e
parser.add_argument("--e2e_algorithm", type=str, default='PGNet') {
parser.add_argument("--e2e_model_dir", type=str) 'name': 'e2e_algorithm',
parser.add_argument("--e2e_limit_side_len", type=float, default=768) 'type': str,
parser.add_argument("--e2e_limit_type", type=str, default='max') 'default': 'PGNet'
},
{
'name': 'e2e_model_dir',
'type': str,
'default': None
},
{
'name': 'e2e_limit_side_len',
'type': float,
'default': 768
},
{
'name': 'e2e_limit_type',
'type': str,
'default': 'max'
},
# PGNet parmas # PGNet parmas
parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5) {
parser.add_argument( 'name': 'e2e_pgnet_score_thresh',
"--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt") 'type': float,
parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext') 'default': 0.5
parser.add_argument("--e2e_pgnet_polygon", type=bool, default=True) },
parser.add_argument("--e2e_pgnet_mode", type=str, default='fast') {
'name': 'e2e_char_dict_path',
'type': str,
'default': './ppocr/utils/ic15_dict.txt'
},
{
'name': 'e2e_pgnet_valid_set',
'type': str,
'default': 'totaltext'
},
{
'name': 'e2e_pgnet_polygon',
'type': str2bool,
'default': True
},
{
'name': 'e2e_pgnet_mode',
'type': str,
'default': 'fast'
},
# params for text classifier # params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False) {
parser.add_argument("--cls_model_dir", type=str) 'name': 'use_angle_cls',
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") 'type': str2bool,
parser.add_argument("--label_list", type=list, default=['0', '180']) 'default': False
parser.add_argument("--cls_batch_num", type=int, default=6) },
parser.add_argument("--cls_thresh", type=float, default=0.9) {
'name': 'cls_model_dir',
'type': str,
'default': None
},
{
'name': 'cls_image_shape',
'type': str,
'default': '3, 48, 192'
},
{
'name': 'label_list',
'type': list,
'default': ['0', '180']
},
{
'name': 'cls_batch_num',
'type': int,
'default': 6
},
{
'name': 'cls_thresh',
'type': float,
'default': 0.9
},
]
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--use_mp", type=str2bool, default=False)
parser.add_argument("--total_process_num", type=int, default=1)
parser.add_argument("--process_id", type=int, default=0)
def parse_args():
parser = argparse.ArgumentParser()
for item in inference_args_list:
parser.add_argument(
'--' + item['name'], type=item['type'], default=item['default'])
return parser.parse_args() return parser.parse_args()
@ -146,7 +347,7 @@ def create_predictor(args, mode, logger):
config.set_mkldnn_cache_capacity(10) config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn() config.enable_mkldnn()
# TODO LDOUBLEV: fix mkldnn bug when bach_size > 1 # TODO LDOUBLEV: fix mkldnn bug when bach_size > 1
#config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'}) # config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'})
args.rec_batch_num = 1 args.rec_batch_num = 1
# enable memory optim # enable memory optim