mirror of
				https://github.com/PaddlePaddle/PaddleOCR.git
				synced 2025-10-30 17:29:13 +00:00 
			
		
		
		
	 22ccdfbf71
			
		
	
	
		22ccdfbf71
		
			
		
	
	
	
	
		
			
			* [cherry-pick] Fix some typos from release/2.4 branch and merge. * rm readme_en.md in deploy/cpp_infer/. * Fix other typos.
		
			
				
	
	
		
			194 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			194 lines
		
	
	
		
			6.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 | |
| #
 | |
| # Licensed under the Apache License, Version 2.0 (the "License");
 | |
| # you may not use this file except in compliance with the License.
 | |
| # You may obtain a copy of the License at
 | |
| #
 | |
| #     http://www.apache.org/licenses/LICENSE-2.0
 | |
| #
 | |
| # Unless required by applicable law or agreed to in writing, software
 | |
| # distributed under the License is distributed on an "AS IS" BASIS,
 | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| # See the License for the specific language governing permissions and
 | |
| # limitations under the License.
 | |
| 
 | |
| import os
 | |
| import re
 | |
| import sys
 | |
| import shapely
 | |
| from shapely.geometry import Polygon
 | |
| import numpy as np
 | |
| from collections import defaultdict
 | |
| import operator
 | |
| import editdistance
 | |
| 
 | |
| 
 | |
def strQ2B(ustring):
    """
    Convert full-width characters in ``ustring`` to their half-width forms.

    The ideographic space (U+3000) becomes an ASCII space, and code points
    U+FF01..U+FF5E are shifted down by 0xFEE0. Every other character is
    returned unchanged.

    Args:
        ustring: iterable of single characters (normally a ``str``).

    Returns:
        str: the converted string.
    """
    fullwidth_space = 0x3000  # 12288
    fullwidth_first = 0xFF01  # 65281
    fullwidth_last = 0xFF5E  # 65374
    fullwidth_offset = 0xFEE0  # 65248

    converted = []
    for uchar in ustring:
        inside_code = ord(uchar)
        if inside_code == fullwidth_space:
            inside_code = 0x20
        elif fullwidth_first <= inside_code <= fullwidth_last:
            inside_code -= fullwidth_offset
        converted.append(chr(inside_code))
    # Join once instead of growing a string with "+=" inside the loop.
    return "".join(converted)
 | |
| 
 | |
| 
 | |
def polygon_from_str(polygon_points):
    """
    Build a shapely geometry from the eight coordinates of a gt or dt line.

    The flat coordinate sequence is viewed as four (x, y) pairs, and the
    convex hull of the resulting polygon is returned.
    """
    corners = np.array(polygon_points).reshape(4, 2)
    return Polygon(corners).convex_hull
 | |
| 
 | |
| 
 | |
def polygon_iou(poly1, poly2):
    """
    Intersection over union between two shapely polygons.

    Args:
        poly1: geometry exposing ``intersects``, ``intersection`` and ``area``.
        poly2: geometry exposing ``area``.

    Returns:
        ``intersection area / union area`` as a float, or 0 when the
        geometries do not intersect, when the union has no area, or when
        computing the intersection raises a TopologicalError.
    """
    # this test is fast and can accelerate calculation
    if not poly1.intersects(poly2):
        return 0

    try:
        inter_area = poly1.intersection(poly2).area
    except shapely.geos.TopologicalError:
        # NOTE(review): newer shapely releases expose this error under
        # ``shapely.errors`` -- confirm against the pinned shapely version.
        print('shapely.geos.TopologicalError occurred, iou set to 0')
        return 0

    union_area = poly1.area + poly2.area - inter_area
    # Degenerate geometries (points / lines) can intersect while having no
    # area at all; without this guard the division raises ZeroDivisionError.
    if union_area <= 0:
        return 0
    return float(inter_area) / union_area
 | |
| 
 | |
| 
 | |
def ed(str1, str2):
    """
    Return the edit distance between ``str1`` and ``str2`` as computed by
    ``editdistance.eval``.
    """
    distance = editdistance.eval(str1, str2)
    return distance
 | |
| 
 | |
| 
 | |
def _read_stripped_lines(path):
    """Return the lines of a UTF-8 text file with surrounding whitespace removed."""
    with open(path, encoding='utf-8') as f:
        return [line.strip() for line in f]


def _parse_gt_lines(gt_lines):
    """Split ground-truth lines into ``(gts, ignore_masks)``.

    Each kept entry of ``gts`` is the 8 coordinate fields plus the text
    (empty when the line has no text field); ``ignore_masks`` holds the 9th
    field of the same line.
    """
    gts = []
    ignore_masks = []
    for line in gt_lines:
        parts = line.strip().split('\t')
        # ignore illegal data
        if len(parts) < 9:
            continue
        assert (len(parts) < 11), "line error: {}".format(line)
        text = '' if len(parts) == 9 else parts[-1]
        gts.append(parts[:8] + [text])
        ignore_masks.append(parts[8])
    return gts, ignore_masks


def _parse_dt_lines(dt_lines):
    """Split detection lines into lists of 8 coordinate fields plus the text."""
    dts = []
    for line in dt_lines:
        parts = line.strip().split("\t")
        # Ignore illegal data (e.g. blank lines), mirroring the ground-truth
        # parsing; such lines would otherwise fail later on dt[8] / float().
        if len(parts) < 8:
            continue
        assert (len(parts) < 10), "line error: {}".format(line)
        dts.append((parts + ['']) if len(parts) == 8 else parts)
    return dts


def e2e_eval(gt_dir, res_dir, ignore_blank=False):
    """
    Evaluate end-to-end results in ``res_dir`` against labels in ``gt_dir``
    and print the resulting metrics.

    Every file name in ``gt_dir`` is looked up in ``res_dir``; a missing
    result file counts as "no detections". Lines are tab separated:

    * ground truth: 8 coordinates, an ignore flag, optional text
    * detection: 8 coordinates, optional text

    Pairs with IoU >= 0.5 are matched greedily in descending IoU order. Only
    ground-truth entries whose flag is ``'0'`` contribute to the statistics;
    a detection matched to any other entry is consumed without being counted.

    Args:
        gt_dir: directory with ground-truth text files.
        res_dir: directory with result text files of the same names.
        ignore_blank: when True, spaces are removed from matched texts
            before they are compared.

    Returns:
        None. The metrics are written to stdout.
    """
    print('start testing...')
    iou_thresh = 0.5
    val_names = os.listdir(gt_dir)
    num_gt_chars = 0
    gt_count = 0
    dt_count = 0
    hit = 0
    ed_sum = 0

    for val_name in val_names:
        gts, ignore_masks = _parse_gt_lines(
            _read_stripped_lines(os.path.join(gt_dir, val_name)))

        val_path = os.path.join(res_dir, val_name)
        if os.path.exists(val_path):
            dts = _parse_dt_lines(_read_stripped_lines(val_path))
        else:
            dts = []

        dt_match = [False] * len(dts)
        gt_match = [False] * len(gts)

        # Build each detection polygon once per file rather than once per
        # ground-truth box; skipped entirely when there is nothing to match.
        if gts:
            dt_polys = [
                polygon_from_str([float(dt_coor) for dt_coor in dt[0:8]])
                for dt in dts
            ]
        else:
            dt_polys = []

        all_ious = {}
        for index_gt, gt in enumerate(gts):
            gt_poly = polygon_from_str(
                [float(gt_coor) for gt_coor in gt[0:8]])
            for index_dt, dt_poly in enumerate(dt_polys):
                iou = polygon_iou(dt_poly, gt_poly)
                if iou >= iou_thresh:
                    all_ious[(index_gt, index_dt)] = iou
        sorted_ious = sorted(
            all_ious.items(), key=operator.itemgetter(1), reverse=True)

        # matched gt and dt
        for (index_gt, index_dt), _ in sorted_ious:
            if gt_match[index_gt] or dt_match[index_dt]:
                continue
            gt_match[index_gt] = True
            dt_match[index_dt] = True
            if ignore_masks[index_gt] != '0':
                # The pair is consumed but an ignored gt is never scored.
                continue
            gt_str = strQ2B(gts[index_gt][8])
            dt_str = strQ2B(dts[index_dt][8])
            if ignore_blank:
                gt_str = gt_str.replace(" ", "")
                dt_str = dt_str.replace(" ", "")
            ed_sum += ed(gt_str, dt_str)
            num_gt_chars += len(gt_str)
            if gt_str == dt_str:
                hit += 1
            gt_count += 1
            dt_count += 1

        # NOTE(review): the unmatched branches below use the raw text, so
        # ``ignore_blank`` does not affect them -- confirm this is intended.

        # unmatched dt
        for tindex, dt_match_flag in enumerate(dt_match):
            if not dt_match_flag:
                ed_sum += ed(dts[tindex][8], '')
                dt_count += 1

        # unmatched gt
        for tindex, gt_match_flag in enumerate(gt_match):
            if not gt_match_flag and ignore_masks[tindex] == '0':
                gt_str = gts[tindex][8]
                ed_sum += ed(gt_str, '')
                num_gt_chars += len(gt_str)
                gt_count += 1

    eps = 1e-9
    print('hit, dt_count, gt_count', hit, dt_count, gt_count)
    precision = hit / (dt_count + eps)
    recall = hit / (gt_count + eps)
    fmeasure = 2.0 * precision * recall / (precision + recall + eps)
    # An empty gt_dir used to raise ZeroDivisionError here.
    avg_edit_dist_img = ed_sum / len(val_names) if val_names else 0.0
    avg_edit_dist_field = ed_sum / (gt_count + eps)
    character_acc = 1 - ed_sum / (num_gt_chars + eps)

    print('character_acc: %.2f' % (character_acc * 100) + "%")
    print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
    print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
    print('precision: %.2f' % (precision * 100) + "%")
    print('recall: %.2f' % (recall * 100) + "%")
    print('fmeasure: %.2f' % (fmeasure * 100) + "%")
 | |
| 
 | |
| 
 | |
if __name__ == '__main__':
    # Both directories are required; print a usage hint instead of failing
    # with a bare IndexError when they are missing. Extra arguments are
    # still tolerated, as before.
    if len(sys.argv) < 3:
        print("usage: python3 {} gt_dir res_dir".format(sys.argv[0]))
        sys.exit(-1)
    gt_folder = sys.argv[1]
    pred_folder = sys.argv[2]
    e2e_eval(gt_folder, pred_folder)
 |