| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  | # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | # you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | # You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #     http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | # distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | # See the License for the specific language governing permissions and | 
					
						
							|  |  |  | # limitations under the License. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from __future__ import absolute_import | 
					
						
							|  |  |  | from __future__ import division | 
					
						
							|  |  |  | from __future__ import print_function | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import os | 
					
						
							|  |  |  | import sys | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  | __dir__ = os.path.dirname(__file__) | 
					
						
							|  |  |  | sys.path.append(__dir__) | 
					
						
							|  |  |  | sys.path.append(os.path.join(__dir__, '..')) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import numpy as np | 
					
						
							|  |  |  | from .locality_aware_nms import nms_locality | 
					
						
							| 
									
										
										
										
											2020-12-15 12:54:02 +00:00
										 |  |  | import paddle | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  | import cv2 | 
					
						
							|  |  |  | import time | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class SASTPostProcess(object): | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  |     The post process for SAST. | 
					
						
							|  |  |  |     """
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, | 
					
						
							|  |  |  |                  score_thresh=0.5, | 
					
						
							|  |  |  |                  nms_thresh=0.2, | 
					
						
							|  |  |  |                  sample_pts_num=2, | 
					
						
							|  |  |  |                  shrink_ratio_of_width=0.3, | 
					
						
							|  |  |  |                  expand_scale=1.0, | 
					
						
							|  |  |  |                  tcl_map_thresh=0.5, | 
					
						
							|  |  |  |                  **kwargs): | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         self.score_thresh = score_thresh | 
					
						
							|  |  |  |         self.nms_thresh = nms_thresh | 
					
						
							|  |  |  |         self.sample_pts_num = sample_pts_num | 
					
						
							|  |  |  |         self.shrink_ratio_of_width = shrink_ratio_of_width | 
					
						
							|  |  |  |         self.expand_scale = expand_scale | 
					
						
							|  |  |  |         self.tcl_map_thresh = tcl_map_thresh | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |         # c++ la-nms is faster, but only support python 3.5 | 
					
						
							|  |  |  |         self.is_python35 = False | 
					
						
							|  |  |  |         if sys.version_info.major == 3 and sys.version_info.minor == 5: | 
					
						
							|  |  |  |             self.is_python35 = True | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |     def point_pair2poly(self, point_pair_list): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Transfer vertical point_pairs into poly point in clockwise. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # constract poly | 
					
						
							|  |  |  |         point_num = len(point_pair_list) * 2 | 
					
						
							|  |  |  |         point_list = [0] * point_num | 
					
						
							|  |  |  |         for idx, point_pair in enumerate(point_pair_list): | 
					
						
							|  |  |  |             point_list[idx] = point_pair[0] | 
					
						
							|  |  |  |             point_list[point_num - 1 - idx] = point_pair[1] | 
					
						
							|  |  |  |         return np.array(point_list).reshape(-1, 2) | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def shrink_quad_along_width(self, | 
					
						
							|  |  |  |                                 quad, | 
					
						
							|  |  |  |                                 begin_width_ratio=0., | 
					
						
							|  |  |  |                                 end_width_ratio=1.): | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |         """ 
 | 
					
						
							|  |  |  |         Generate shrink_quad_along_width. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |         ratio_pair = np.array( | 
					
						
							|  |  |  |             [[begin_width_ratio], [end_width_ratio]], dtype=np.float32) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |         p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair | 
					
						
							|  |  |  |         p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair | 
					
						
							|  |  |  |         return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]]) | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |     def expand_poly_along_width(self, poly, shrink_ratio_of_width=0.3): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         expand poly along width. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         point_num = poly.shape[0] | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |         left_quad = np.array( | 
					
						
							|  |  |  |             [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |         left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \ | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |                      (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6) | 
					
						
							|  |  |  |         left_quad_expand = self.shrink_quad_along_width(left_quad, left_ratio, | 
					
						
							|  |  |  |                                                         1.0) | 
					
						
							|  |  |  |         right_quad = np.array( | 
					
						
							|  |  |  |             [ | 
					
						
							|  |  |  |                 poly[point_num // 2 - 2], poly[point_num // 2 - 1], | 
					
						
							|  |  |  |                 poly[point_num // 2], poly[point_num // 2 + 1] | 
					
						
							|  |  |  |             ], | 
					
						
							|  |  |  |             dtype=np.float32) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |         right_ratio = 1.0 + \ | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |                       shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \ | 
					
						
							|  |  |  |                       (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6) | 
					
						
							|  |  |  |         right_quad_expand = self.shrink_quad_along_width(right_quad, 0.0, | 
					
						
							|  |  |  |                                                          right_ratio) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |         poly[0] = left_quad_expand[0] | 
					
						
							|  |  |  |         poly[-1] = left_quad_expand[-1] | 
					
						
							|  |  |  |         poly[point_num // 2 - 1] = right_quad_expand[1] | 
					
						
							|  |  |  |         poly[point_num // 2] = right_quad_expand[2] | 
					
						
							|  |  |  |         return poly | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def restore_quad(self, tcl_map, tcl_map_thresh, tvo_map): | 
					
						
							|  |  |  |         """Restore quad.""" | 
					
						
							|  |  |  |         xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh) | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |         xy_text = xy_text[:, ::-1]  # (n, 2) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # Sort the text boxes via the y axis | 
					
						
							|  |  |  |         xy_text = xy_text[np.argsort(xy_text[:, 1])] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0] | 
					
						
							|  |  |  |         scores = scores[:, np.newaxis] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Restore | 
					
						
							|  |  |  |         point_num = int(tvo_map.shape[-1] / 2) | 
					
						
							|  |  |  |         assert point_num == 4 | 
					
						
							|  |  |  |         tvo_map = tvo_map[xy_text[:, 1], xy_text[:, 0], :] | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |         xy_text_tile = np.tile(xy_text, (1, point_num))  # (n, point_num * 2) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |         quads = xy_text_tile - tvo_map | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return scores, quads, xy_text | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def quad_area(self, quad): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         compute area of a quad. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |         edge = [(quad[1][0] - quad[0][0]) * (quad[1][1] + quad[0][1]), | 
					
						
							|  |  |  |                 (quad[2][0] - quad[1][0]) * (quad[2][1] + quad[1][1]), | 
					
						
							|  |  |  |                 (quad[3][0] - quad[2][0]) * (quad[3][1] + quad[2][1]), | 
					
						
							|  |  |  |                 (quad[0][0] - quad[3][0]) * (quad[0][1] + quad[3][1])] | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |         return np.sum(edge) / 2. | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |     def nms(self, dets): | 
					
						
							|  |  |  |         if self.is_python35: | 
					
						
							|  |  |  |             import lanms | 
					
						
							|  |  |  |             dets = lanms.merge_quadrangle_n9(dets, self.nms_thresh) | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             dets = nms_locality(dets, self.nms_thresh) | 
					
						
							|  |  |  |         return dets | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def cluster_by_quads_tco(self, tcl_map, tcl_map_thresh, quads, tco_map): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Cluster pixels in tcl_map based on quads. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |         instance_count = quads.shape[0] + 1  # contain background | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |         instance_label_map = np.zeros(tcl_map.shape[:2], dtype=np.int32) | 
					
						
							|  |  |  |         if instance_count == 1: | 
					
						
							|  |  |  |             return instance_count, instance_label_map | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # predict text center | 
					
						
							|  |  |  |         xy_text = np.argwhere(tcl_map[:, :, 0] > tcl_map_thresh) | 
					
						
							|  |  |  |         n = xy_text.shape[0] | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |         xy_text = xy_text[:, ::-1]  # (n, 2) | 
					
						
							|  |  |  |         tco = tco_map[xy_text[:, 1], xy_text[:, 0], :]  # (n, 2) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |         pred_tc = xy_text - tco | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |         # get gt text center | 
					
						
							|  |  |  |         m = quads.shape[0] | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |         gt_tc = np.mean(quads, axis=1)  # (m, 2) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |         pred_tc_tile = np.tile(pred_tc[:, np.newaxis, :], | 
					
						
							|  |  |  |                                (1, m, 1))  # (n, m, 2) | 
					
						
							|  |  |  |         gt_tc_tile = np.tile(gt_tc[np.newaxis, :, :], (n, 1, 1))  # (n, m, 2) | 
					
						
							|  |  |  |         dist_mat = np.linalg.norm(pred_tc_tile - gt_tc_tile, axis=2)  # (n, m) | 
					
						
							|  |  |  |         xy_text_assign = np.argmin(dist_mat, axis=1) + 1  # (n,) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         instance_label_map[xy_text[:, 1], xy_text[:, 0]] = xy_text_assign | 
					
						
							|  |  |  |         return instance_count, instance_label_map | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def estimate_sample_pts_num(self, quad, xy_text): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Estimate sample points number. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |         eh = (np.linalg.norm(quad[0] - quad[3]) + | 
					
						
							|  |  |  |               np.linalg.norm(quad[1] - quad[2])) / 2.0 | 
					
						
							|  |  |  |         ew = (np.linalg.norm(quad[0] - quad[1]) + | 
					
						
							|  |  |  |               np.linalg.norm(quad[2] - quad[3])) / 2.0 | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         dense_sample_pts_num = max(2, int(ew)) | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |         dense_xy_center_line = xy_text[np.linspace( | 
					
						
							|  |  |  |             0, | 
					
						
							|  |  |  |             xy_text.shape[0] - 1, | 
					
						
							|  |  |  |             dense_sample_pts_num, | 
					
						
							|  |  |  |             endpoint=True, | 
					
						
							|  |  |  |             dtype=np.float32).astype(np.int32)] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         dense_xy_center_line_diff = dense_xy_center_line[ | 
					
						
							|  |  |  |             1:] - dense_xy_center_line[:-1] | 
					
						
							|  |  |  |         estimate_arc_len = np.sum( | 
					
						
							|  |  |  |             np.linalg.norm( | 
					
						
							|  |  |  |                 dense_xy_center_line_diff, axis=1)) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         sample_pts_num = max(2, int(estimate_arc_len / eh)) | 
					
						
							|  |  |  |         return sample_pts_num | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |     def detect_sast(self, | 
					
						
							|  |  |  |                     tcl_map, | 
					
						
							|  |  |  |                     tvo_map, | 
					
						
							|  |  |  |                     tbo_map, | 
					
						
							|  |  |  |                     tco_map, | 
					
						
							|  |  |  |                     ratio_w, | 
					
						
							|  |  |  |                     ratio_h, | 
					
						
							|  |  |  |                     src_w, | 
					
						
							|  |  |  |                     src_h, | 
					
						
							|  |  |  |                     shrink_ratio_of_width=0.3, | 
					
						
							|  |  |  |                     tcl_map_thresh=0.5, | 
					
						
							|  |  |  |                     offset_expand=1.0, | 
					
						
							|  |  |  |                     out_strid=4.0): | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |         """
 | 
					
						
							|  |  |  |         first resize the tcl_map, tvo_map and tbo_map to the input_size, then restore the polys | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         # restore quad | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |         scores, quads, xy_text = self.restore_quad(tcl_map, tcl_map_thresh, | 
					
						
							|  |  |  |                                                    tvo_map) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |         dets = np.hstack((quads, scores)).astype(np.float32, copy=False) | 
					
						
							|  |  |  |         dets = self.nms(dets) | 
					
						
							|  |  |  |         if dets.shape[0] == 0: | 
					
						
							|  |  |  |             return [] | 
					
						
							|  |  |  |         quads = dets[:, :-1].reshape(-1, 4, 2) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Compute quad area | 
					
						
							|  |  |  |         quad_areas = [] | 
					
						
							|  |  |  |         for quad in quads: | 
					
						
							|  |  |  |             quad_areas.append(-self.quad_area(quad)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # instance segmentation | 
					
						
							|  |  |  |         # instance_count, instance_label_map = cv2.connectedComponents(tcl_map.astype(np.uint8), connectivity=8) | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |         instance_count, instance_label_map = self.cluster_by_quads_tco( | 
					
						
							|  |  |  |             tcl_map, tcl_map_thresh, quads, tco_map) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |         # restore single poly with tcl instance. | 
					
						
							|  |  |  |         poly_list = [] | 
					
						
							|  |  |  |         for instance_idx in range(1, instance_count): | 
					
						
							|  |  |  |             xy_text = np.argwhere(instance_label_map == instance_idx)[:, ::-1] | 
					
						
							|  |  |  |             quad = quads[instance_idx - 1] | 
					
						
							|  |  |  |             q_area = quad_areas[instance_idx - 1] | 
					
						
							|  |  |  |             if q_area < 5: | 
					
						
							|  |  |  |                 continue | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |             # | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |             len1 = float(np.linalg.norm(quad[0] - quad[1])) | 
					
						
							|  |  |  |             len2 = float(np.linalg.norm(quad[1] - quad[2])) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |             min_len = min(len1, len2) | 
					
						
							|  |  |  |             if min_len < 3: | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # filter small CC | 
					
						
							|  |  |  |             if xy_text.shape[0] <= 0: | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # filter low confidence instance | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |             xy_text_scores = tcl_map[xy_text[:, 1], xy_text[:, 0], 0] | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |             if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.1: | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |                 # if np.sum(xy_text_scores) / quad_areas[instance_idx - 1] < 0.05: | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |                 continue | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # sort xy_text | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |             left_center_pt = np.array( | 
					
						
							|  |  |  |                 [[(quad[0, 0] + quad[-1, 0]) / 2.0, | 
					
						
							|  |  |  |                   (quad[0, 1] + quad[-1, 1]) / 2.0]])  # (1, 2) | 
					
						
							|  |  |  |             right_center_pt = np.array( | 
					
						
							|  |  |  |                 [[(quad[1, 0] + quad[2, 0]) / 2.0, | 
					
						
							|  |  |  |                   (quad[1, 1] + quad[2, 1]) / 2.0]])  # (1, 2) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |             proj_unit_vec = (right_center_pt - left_center_pt) / \ | 
					
						
							|  |  |  |                             (np.linalg.norm(right_center_pt - left_center_pt) + 1e-6) | 
					
						
							|  |  |  |             proj_value = np.sum(xy_text * proj_unit_vec, axis=1) | 
					
						
							|  |  |  |             xy_text = xy_text[np.argsort(proj_value)] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Sample pts in tcl map | 
					
						
							|  |  |  |             if self.sample_pts_num == 0: | 
					
						
							|  |  |  |                 sample_pts_num = self.estimate_sample_pts_num(quad, xy_text) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 sample_pts_num = self.sample_pts_num | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |             xy_center_line = xy_text[np.linspace( | 
					
						
							|  |  |  |                 0, | 
					
						
							|  |  |  |                 xy_text.shape[0] - 1, | 
					
						
							|  |  |  |                 sample_pts_num, | 
					
						
							|  |  |  |                 endpoint=True, | 
					
						
							|  |  |  |                 dtype=np.float32).astype(np.int32)] | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |             point_pair_list = [] | 
					
						
							|  |  |  |             for x, y in xy_center_line: | 
					
						
							|  |  |  |                 # get corresponding offset | 
					
						
							|  |  |  |                 offset = tbo_map[y, x, :].reshape(2, 2) | 
					
						
							|  |  |  |                 if offset_expand != 1.0: | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |                     offset_length = np.linalg.norm( | 
					
						
							|  |  |  |                         offset, axis=1, keepdims=True) | 
					
						
							|  |  |  |                     expand_length = np.clip( | 
					
						
							|  |  |  |                         offset_length * (offset_expand - 1), | 
					
						
							|  |  |  |                         a_min=0.5, | 
					
						
							|  |  |  |                         a_max=3.0) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |                     offset_detal = offset / offset_length * expand_length | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |                     offset = offset + offset_detal | 
					
						
							|  |  |  |                     # original point | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |                 ori_yx = np.array([y, x], dtype=np.float32) | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |                 point_pair = (ori_yx + offset)[:, ::-1] * out_strid / np.array( | 
					
						
							|  |  |  |                     [ratio_w, ratio_h]).reshape(-1, 2) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |                 point_pair_list.append(point_pair) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # ndarry: (x, 2), expand poly along width | 
					
						
							|  |  |  |             detected_poly = self.point_pair2poly(point_pair_list) | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |             detected_poly = self.expand_poly_along_width(detected_poly, | 
					
						
							|  |  |  |                                                          shrink_ratio_of_width) | 
					
						
							|  |  |  |             detected_poly[:, 0] = np.clip( | 
					
						
							|  |  |  |                 detected_poly[:, 0], a_min=0, a_max=src_w) | 
					
						
							|  |  |  |             detected_poly[:, 1] = np.clip( | 
					
						
							|  |  |  |                 detected_poly[:, 1], a_min=0, a_max=src_h) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |             poly_list.append(detected_poly) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return poly_list | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |     def __call__(self, outs_dict, shape_list): | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |         score_list = outs_dict['f_score'] | 
					
						
							|  |  |  |         border_list = outs_dict['f_border'] | 
					
						
							|  |  |  |         tvo_list = outs_dict['f_tvo'] | 
					
						
							|  |  |  |         tco_list = outs_dict['f_tco'] | 
					
						
							| 
									
										
										
										
											2020-12-15 12:54:02 +00:00
										 |  |  |         if isinstance(score_list, paddle.Tensor): | 
					
						
							|  |  |  |             score_list = score_list.numpy() | 
					
						
							|  |  |  |             border_list = border_list.numpy() | 
					
						
							|  |  |  |             tvo_list = tvo_list.numpy() | 
					
						
							|  |  |  |             tco_list = tco_list.numpy() | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |         img_num = len(shape_list) | 
					
						
							|  |  |  |         poly_lists = [] | 
					
						
							|  |  |  |         for ino in range(img_num): | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |             p_score = score_list[ino].transpose((1, 2, 0)) | 
					
						
							|  |  |  |             p_border = border_list[ino].transpose((1, 2, 0)) | 
					
						
							|  |  |  |             p_tvo = tvo_list[ino].transpose((1, 2, 0)) | 
					
						
							|  |  |  |             p_tco = tco_list[ino].transpose((1, 2, 0)) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |             src_h, src_w, ratio_h, ratio_w = shape_list[ino] | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-03-08 14:15:47 +08:00
										 |  |  |             poly_list = self.detect_sast( | 
					
						
							|  |  |  |                 p_score, | 
					
						
							|  |  |  |                 p_tvo, | 
					
						
							|  |  |  |                 p_border, | 
					
						
							|  |  |  |                 p_tco, | 
					
						
							|  |  |  |                 ratio_w, | 
					
						
							|  |  |  |                 ratio_h, | 
					
						
							|  |  |  |                 src_w, | 
					
						
							|  |  |  |                 src_h, | 
					
						
							|  |  |  |                 shrink_ratio_of_width=self.shrink_ratio_of_width, | 
					
						
							|  |  |  |                 tcl_map_thresh=self.tcl_map_thresh, | 
					
						
							|  |  |  |                 offset_expand=self.expand_scale) | 
					
						
							| 
									
										
										
										
											2020-12-09 06:45:25 +00:00
										 |  |  |             poly_lists.append({'points': np.array(poly_list)}) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return poly_lists |