mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-07-25 18:00:07 +00:00
add pad for small image in det
This commit is contained in:
parent
48eba02894
commit
dec76eb75d
@ -81,7 +81,7 @@ class NormalizeImage(object):
|
|||||||
assert isinstance(img,
|
assert isinstance(img,
|
||||||
np.ndarray), "invalid input 'img' in NormalizeImage"
|
np.ndarray), "invalid input 'img' in NormalizeImage"
|
||||||
data['image'] = (
|
data['image'] = (
|
||||||
img.astype('float32') * self.scale - self.mean) / self.std
|
img.astype('float32') * self.scale - self.mean) / self.std
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@ -122,6 +122,8 @@ class DetResizeForTest(object):
|
|||||||
elif 'limit_side_len' in kwargs:
|
elif 'limit_side_len' in kwargs:
|
||||||
self.limit_side_len = kwargs['limit_side_len']
|
self.limit_side_len = kwargs['limit_side_len']
|
||||||
self.limit_type = kwargs.get('limit_type', 'min')
|
self.limit_type = kwargs.get('limit_type', 'min')
|
||||||
|
self.pad = kwargs.get('pad', False)
|
||||||
|
self.pad_size = kwargs.get('pad_size', 480)
|
||||||
elif 'resize_long' in kwargs:
|
elif 'resize_long' in kwargs:
|
||||||
self.resize_type = 2
|
self.resize_type = 2
|
||||||
self.resize_long = kwargs.get('resize_long', 960)
|
self.resize_long = kwargs.get('resize_long', 960)
|
||||||
@ -163,7 +165,7 @@ class DetResizeForTest(object):
|
|||||||
img, (ratio_h, ratio_w)
|
img, (ratio_h, ratio_w)
|
||||||
"""
|
"""
|
||||||
limit_side_len = self.limit_side_len
|
limit_side_len = self.limit_side_len
|
||||||
h, w, _ = img.shape
|
h, w, c = img.shape
|
||||||
|
|
||||||
# limit the max side
|
# limit the max side
|
||||||
if self.limit_type == 'max':
|
if self.limit_type == 'max':
|
||||||
@ -172,6 +174,8 @@ class DetResizeForTest(object):
|
|||||||
ratio = float(limit_side_len) / h
|
ratio = float(limit_side_len) / h
|
||||||
else:
|
else:
|
||||||
ratio = float(limit_side_len) / w
|
ratio = float(limit_side_len) / w
|
||||||
|
elif self.pad:
|
||||||
|
ratio = float(self.pad_size) / max(h, w)
|
||||||
else:
|
else:
|
||||||
ratio = 1.
|
ratio = 1.
|
||||||
else:
|
else:
|
||||||
@ -197,6 +201,10 @@ class DetResizeForTest(object):
|
|||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
ratio_h = resize_h / float(h)
|
ratio_h = resize_h / float(h)
|
||||||
ratio_w = resize_w / float(w)
|
ratio_w = resize_w / float(w)
|
||||||
|
if self.limit_type == 'max' and self.pad:
|
||||||
|
padding_im = np.zeros((self.pad_size, self.pad_size, c), dtype=np.float32)
|
||||||
|
padding_im[:resize_h, :resize_w, :] = img
|
||||||
|
img = padding_im
|
||||||
return img, [ratio_h, ratio_w]
|
return img, [ratio_h, ratio_w]
|
||||||
|
|
||||||
def resize_image_type2(self, img):
|
def resize_image_type2(self, img):
|
||||||
|
@ -49,12 +49,12 @@ class DBPostProcess(object):
|
|||||||
self.dilation_kernel = None if not use_dilation else np.array(
|
self.dilation_kernel = None if not use_dilation else np.array(
|
||||||
[[1, 1], [1, 1]])
|
[[1, 1], [1, 1]])
|
||||||
|
|
||||||
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
|
def boxes_from_bitmap(self, pred, _bitmap, shape):
|
||||||
'''
|
'''
|
||||||
_bitmap: single map with shape (1, H, W),
|
_bitmap: single map with shape (1, H, W),
|
||||||
whose values are binarized as {0, 1}
|
whose values are binarized as {0, 1}
|
||||||
'''
|
'''
|
||||||
|
dest_height, dest_width, ratio_h, ratio_w = shape
|
||||||
bitmap = _bitmap
|
bitmap = _bitmap
|
||||||
height, width = bitmap.shape
|
height, width = bitmap.shape
|
||||||
|
|
||||||
@ -89,9 +89,9 @@ class DBPostProcess(object):
|
|||||||
box = np.array(box)
|
box = np.array(box)
|
||||||
|
|
||||||
box[:, 0] = np.clip(
|
box[:, 0] = np.clip(
|
||||||
np.round(box[:, 0] / width * dest_width), 0, dest_width)
|
np.round(box[:, 0] / ratio_w), 0, dest_width)
|
||||||
box[:, 1] = np.clip(
|
box[:, 1] = np.clip(
|
||||||
np.round(box[:, 1] / height * dest_height), 0, dest_height)
|
np.round(box[:, 1] / ratio_h), 0, dest_height)
|
||||||
boxes.append(box.astype(np.int16))
|
boxes.append(box.astype(np.int16))
|
||||||
scores.append(score)
|
scores.append(score)
|
||||||
return np.array(boxes, dtype=np.int16), scores
|
return np.array(boxes, dtype=np.int16), scores
|
||||||
@ -175,7 +175,6 @@ class DBPostProcess(object):
|
|||||||
|
|
||||||
boxes_batch = []
|
boxes_batch = []
|
||||||
for batch_index in range(pred.shape[0]):
|
for batch_index in range(pred.shape[0]):
|
||||||
src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
|
|
||||||
if self.dilation_kernel is not None:
|
if self.dilation_kernel is not None:
|
||||||
mask = cv2.dilate(
|
mask = cv2.dilate(
|
||||||
np.array(segmentation[batch_index]).astype(np.uint8),
|
np.array(segmentation[batch_index]).astype(np.uint8),
|
||||||
@ -183,7 +182,7 @@ class DBPostProcess(object):
|
|||||||
else:
|
else:
|
||||||
mask = segmentation[batch_index]
|
mask = segmentation[batch_index]
|
||||||
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
|
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
|
||||||
src_w, src_h)
|
shape_list[batch_index])
|
||||||
|
|
||||||
boxes_batch.append({'points': boxes})
|
boxes_batch.append({'points': boxes})
|
||||||
return boxes_batch
|
return boxes_batch
|
||||||
|
@ -38,11 +38,13 @@ logger = get_logger()
|
|||||||
|
|
||||||
class OCRSystem(object):
|
class OCRSystem(object):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
|
args.det_pad = True
|
||||||
|
args.det_pad_size = 640
|
||||||
self.text_system = TextSystem(args)
|
self.text_system = TextSystem(args)
|
||||||
self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
|
self.table_system = TableSystem(args, self.text_system.text_detector, self.text_system.text_recognizer)
|
||||||
self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
|
self.table_layout = lp.PaddleDetectionLayoutModel("lp://PubLayNet/ppyolov2_r50vd_dcn_365e_publaynet/config",
|
||||||
threshold=0.5, enable_mkldnn=args.enable_mkldnn,
|
threshold=0.5, enable_mkldnn=args.enable_mkldnn,
|
||||||
enforce_cpu=not args.use_gpu,thread_num=args.cpu_threads)
|
enforce_cpu=not args.use_gpu, thread_num=args.cpu_threads)
|
||||||
self.use_angle_cls = args.use_angle_cls
|
self.use_angle_cls = args.use_angle_cls
|
||||||
self.drop_score = args.drop_score
|
self.drop_score = args.drop_score
|
||||||
|
|
||||||
@ -67,7 +69,6 @@ class OCRSystem(object):
|
|||||||
res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res})
|
res_list.append({'type': region.type, 'bbox': [x1, y1, x2, y2], 'res': res})
|
||||||
return res_list
|
return res_list
|
||||||
|
|
||||||
|
|
||||||
def save_res(res, save_folder, img_name):
|
def save_res(res, save_folder, img_name):
|
||||||
excel_save_folder = os.path.join(save_folder, img_name)
|
excel_save_folder = os.path.join(save_folder, img_name)
|
||||||
os.makedirs(excel_save_folder, exist_ok=True)
|
os.makedirs(excel_save_folder, exist_ok=True)
|
||||||
|
@ -41,7 +41,9 @@ class TextDetector(object):
|
|||||||
pre_process_list = [{
|
pre_process_list = [{
|
||||||
'DetResizeForTest': {
|
'DetResizeForTest': {
|
||||||
'limit_side_len': args.det_limit_side_len,
|
'limit_side_len': args.det_limit_side_len,
|
||||||
'limit_type': args.det_limit_type
|
'limit_type': args.det_limit_type,
|
||||||
|
'pad':args.det_pad,
|
||||||
|
'pad_size':args.det_pad_size
|
||||||
}
|
}
|
||||||
}, {
|
}, {
|
||||||
'NormalizeImage': {
|
'NormalizeImage': {
|
||||||
|
@ -46,6 +46,8 @@ def init_args():
|
|||||||
parser.add_argument("--det_model_dir", type=str)
|
parser.add_argument("--det_model_dir", type=str)
|
||||||
parser.add_argument("--det_limit_side_len", type=float, default=960)
|
parser.add_argument("--det_limit_side_len", type=float, default=960)
|
||||||
parser.add_argument("--det_limit_type", type=str, default='max')
|
parser.add_argument("--det_limit_type", type=str, default='max')
|
||||||
|
parser.add_argument("--det_pad", type=str2bool, default=False)
|
||||||
|
parser.add_argument("--det_pad_size", type=int, default=640)
|
||||||
|
|
||||||
# DB params
|
# DB params
|
||||||
parser.add_argument("--det_db_thresh", type=float, default=0.3)
|
parser.add_argument("--det_db_thresh", type=float, default=0.3)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user