mirror of
https://github.com/PaddlePaddle/PaddleOCR.git
synced 2025-11-01 18:29:18 +00:00
优化tools/infer/predict_system.py代码
This commit is contained in:
parent
a28ef7f026
commit
ef156b19e1
@ -13,9 +13,9 @@
|
||||
# limitations under the License.
|
||||
import os
|
||||
import sys
|
||||
__dir__ = os.path.dirname(__file__)
|
||||
__dir__ = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append(__dir__)
|
||||
sys.path.append(os.path.join(__dir__, '../..'))
|
||||
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
|
||||
|
||||
import tools.infer.utility as utility
|
||||
from ppocr.utils.utility import initial_logger
|
||||
@ -39,6 +39,7 @@ class TextSystem(object):
|
||||
self.text_recognizer = predict_rec.TextRecognizer(args)
|
||||
|
||||
def get_rotate_crop_image(self, img, points):
|
||||
'''
|
||||
img_height, img_width = img.shape[0:2]
|
||||
left = int(np.min(points[:, 0]))
|
||||
right = int(np.max(points[:, 0]))
|
||||
@ -47,15 +48,19 @@ class TextSystem(object):
|
||||
img_crop = img[top:bottom, left:right, :].copy()
|
||||
points[:, 0] = points[:, 0] - left
|
||||
points[:, 1] = points[:, 1] - top
|
||||
img_crop_width = int(np.linalg.norm(points[0] - points[1]))
|
||||
img_crop_height = int(np.linalg.norm(points[0] - points[3]))
|
||||
pts_std = np.float32([[0, 0], [img_crop_width, 0],\
|
||||
[img_crop_width, img_crop_height], [0, img_crop_height]])
|
||||
'''
|
||||
img_crop_width = int(max(np.linalg.norm(points[0] - points[1]),
|
||||
np.linalg.norm(points[2] - points[3])))
|
||||
img_crop_height = int(max(np.linalg.norm(points[0] - points[3]),
|
||||
np.linalg.norm(points[1] - points[2])))
|
||||
pts_std = np.float32([[0, 0],
|
||||
[img_crop_width, 0],
|
||||
[img_crop_width, img_crop_height],
|
||||
[0, img_crop_height]])
|
||||
M = cv2.getPerspectiveTransform(points, pts_std)
|
||||
dst_img = cv2.warpPerspective(
|
||||
img_crop,
|
||||
M, (img_crop_width, img_crop_height),
|
||||
borderMode=cv2.BORDER_REPLICATE)
|
||||
dst_img = cv2.warpPerspective(img, M, (img_crop_width, img_crop_height),
|
||||
borderMode=cv2.BORDER_REPLICATE,
|
||||
flags=cv2.INTER_CUBIC)
|
||||
dst_img_height, dst_img_width = dst_img.shape[0:2]
|
||||
if dst_img_height * 1.0 / dst_img_width >= 1.5:
|
||||
dst_img = np.rot90(dst_img)
|
||||
@ -106,8 +111,7 @@ def sorted_boxes(dt_boxes):
|
||||
return _boxes
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = utility.parse_args()
|
||||
def main(args):
|
||||
image_file_list = get_image_file_list(args.image_dir)
|
||||
text_sys = TextSystem(args)
|
||||
is_visualize = True
|
||||
@ -145,3 +149,7 @@ if __name__ == "__main__":
|
||||
draw_img[:, :, ::-1])
|
||||
print("The visualized image saved in {}".format(
|
||||
os.path.join(draw_img_save, os.path.basename(image_file))))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(utility.parse_args())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user