| 
									
										
										
										
											2021-08-25 13:38:20 +00:00
										 |  |  | import numpy as np | 
					
						
							|  |  |  | import os | 
					
						
							|  |  |  | import subprocess | 
					
						
							|  |  |  | import json | 
					
						
							|  |  |  | import argparse | 
					
						
							| 
									
										
										
										
											2021-08-26 03:30:06 +00:00
										 |  |  | import glob | 
					
						
							| 
									
										
										
										
											2021-08-25 13:38:20 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def init_args(): | 
					
						
							|  |  |  |     parser = argparse.ArgumentParser() | 
					
						
							| 
									
										
										
										
											2021-08-26 03:40:28 +00:00
										 |  |  |     # params for testing assert allclose | 
					
						
							| 
									
										
										
										
											2021-08-25 13:38:20 +00:00
										 |  |  |     parser.add_argument("--atol", type=float, default=1e-3) | 
					
						
							|  |  |  |     parser.add_argument("--rtol", type=float, default=1e-3) | 
					
						
							|  |  |  |     parser.add_argument("--gt_file", type=str, default="") | 
					
						
							|  |  |  |     parser.add_argument("--log_file", type=str, default="") | 
					
						
							| 
									
										
										
										
											2021-08-26 03:30:06 +00:00
										 |  |  |     parser.add_argument("--precision", type=str, default="fp32") | 
					
						
							| 
									
										
										
										
											2021-08-25 13:38:20 +00:00
										 |  |  |     return parser | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def parse_args(): | 
					
						
							|  |  |  |     parser = init_args() | 
					
						
							|  |  |  |     return parser.parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def run_shell_command(cmd): | 
					
						
							|  |  |  |     p = subprocess.Popen( | 
					
						
							|  |  |  |         cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True) | 
					
						
							|  |  |  |     out, err = p.communicate() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if p.returncode == 0: | 
					
						
							|  |  |  |         return out.decode('utf-8') | 
					
						
							|  |  |  |     else: | 
					
						
							|  |  |  |         return None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-10-27 14:10:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-25 13:38:20 +00:00
										 |  |  | def parser_results_from_log_by_name(log_path, names_list): | 
					
						
							|  |  |  |     if not os.path.exists(log_path): | 
					
						
							|  |  |  |         raise ValueError("The log file {} does not exists!".format(log_path)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     if names_list is None or len(names_list) < 1: | 
					
						
							|  |  |  |         return [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     parser_results = {} | 
					
						
							|  |  |  |     for name in names_list: | 
					
						
							|  |  |  |         cmd = "grep {} {}".format(name, log_path) | 
					
						
							|  |  |  |         outs = run_shell_command(cmd) | 
					
						
							|  |  |  |         outs = outs.split("\n")[0] | 
					
						
							|  |  |  |         result = outs.split("{}".format(name))[-1] | 
					
						
							| 
									
										
										
										
											2021-09-17 08:33:05 +00:00
										 |  |  |         try: | 
					
						
							|  |  |  |             result = json.loads(result) | 
					
						
							|  |  |  |         except: | 
					
						
							|  |  |  |             result = np.array([int(r) for r in result.split()]).reshape(-1, 4) | 
					
						
							| 
									
										
										
										
											2021-08-25 13:38:20 +00:00
										 |  |  |         parser_results[name] = result | 
					
						
							|  |  |  |     return parser_results | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-10-27 14:10:24 +08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-25 13:38:20 +00:00
										 |  |  | def load_gt_from_file(gt_file): | 
					
						
							|  |  |  |     if not os.path.exists(gt_file): | 
					
						
							|  |  |  |         raise ValueError("The log file {} does not exists!".format(gt_file)) | 
					
						
							|  |  |  |     with open(gt_file, 'r') as f: | 
					
						
							|  |  |  |         data = f.readlines() | 
					
						
							|  |  |  |         f.close() | 
					
						
							|  |  |  |     parser_gt = {} | 
					
						
							|  |  |  |     for line in data: | 
					
						
							|  |  |  |         image_name, result = line.strip("\n").split("\t") | 
					
						
							| 
									
										
										
										
											2021-09-17 08:33:05 +00:00
										 |  |  |         image_name = image_name.split('/')[-1] | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             result = json.loads(result) | 
					
						
							|  |  |  |         except: | 
					
						
							|  |  |  |             result = np.array([int(r) for r in result.split()]).reshape(-1, 4) | 
					
						
							| 
									
										
										
										
											2021-08-25 13:38:20 +00:00
										 |  |  |         parser_gt[image_name] = result | 
					
						
							|  |  |  |     return parser_gt | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-26 03:30:06 +00:00
										 |  |  | def load_gt_from_txts(gt_file): | 
					
						
							|  |  |  |     gt_list = glob.glob(gt_file) | 
					
						
							|  |  |  |     gt_collection = {} | 
					
						
							|  |  |  |     for gt_f in gt_list: | 
					
						
							|  |  |  |         gt_dict = load_gt_from_file(gt_f) | 
					
						
							|  |  |  |         basename = os.path.basename(gt_f) | 
					
						
							|  |  |  |         if "fp32" in basename: | 
					
						
							|  |  |  |             gt_collection["fp32"] = [gt_dict, gt_f] | 
					
						
							|  |  |  |         elif "fp16" in basename: | 
					
						
							|  |  |  |             gt_collection["fp16"] = [gt_dict, gt_f] | 
					
						
							|  |  |  |         elif "int8" in basename: | 
					
						
							|  |  |  |             gt_collection["int8"] = [gt_dict, gt_f] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             continue | 
					
						
							|  |  |  |     return gt_collection | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | def collect_predict_from_logs(log_path, key_list): | 
					
						
							|  |  |  |     log_list = glob.glob(log_path) | 
					
						
							|  |  |  |     pred_collection = {} | 
					
						
							|  |  |  |     for log_f in log_list: | 
					
						
							|  |  |  |         pred_dict = parser_results_from_log_by_name(log_f, key_list) | 
					
						
							|  |  |  |         key = os.path.basename(log_f) | 
					
						
							|  |  |  |         pred_collection[key] = pred_dict | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     return pred_collection | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-25 13:38:20 +00:00
										 |  |  | def testing_assert_allclose(dict_x, dict_y, atol=1e-7, rtol=1e-7): | 
					
						
							|  |  |  |     for k in dict_x: | 
					
						
							|  |  |  |         np.testing.assert_allclose( | 
					
						
							|  |  |  |             np.array(dict_x[k]), np.array(dict_y[k]), atol=atol, rtol=rtol) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							|  |  |  |     # Usage: | 
					
						
							| 
									
										
										
										
											2021-08-26 03:30:06 +00:00
										 |  |  |     # python3.7 tests/compare_results.py --gt_file=./tests/results/*.txt  --log_file=./tests/output/infer_*.log | 
					
						
							| 
									
										
										
										
											2021-08-25 13:38:20 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     args = parse_args() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-08-26 03:30:06 +00:00
										 |  |  |     gt_collection = load_gt_from_txts(args.gt_file) | 
					
						
							|  |  |  |     key_list = gt_collection["fp32"][0].keys() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     pred_collection = collect_predict_from_logs(args.log_file, key_list) | 
					
						
							|  |  |  |     for filename in pred_collection.keys(): | 
					
						
							|  |  |  |         if "fp32" in filename: | 
					
						
							|  |  |  |             gt_dict, gt_filename = gt_collection["fp32"] | 
					
						
							|  |  |  |         elif "fp16" in filename: | 
					
						
							|  |  |  |             gt_dict, gt_filename = gt_collection["fp16"] | 
					
						
							|  |  |  |         elif "int8" in filename: | 
					
						
							|  |  |  |             gt_dict, gt_filename = gt_collection["int8"] | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             continue | 
					
						
							|  |  |  |         pred_dict = pred_collection[filename] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             testing_assert_allclose( | 
					
						
							|  |  |  |                 gt_dict, pred_dict, atol=args.atol, rtol=args.rtol) | 
					
						
							|  |  |  |             print( | 
					
						
							|  |  |  |                 "Assert allclose passed! The results of {} and {} are consistent!". | 
					
						
							|  |  |  |                 format(filename, gt_filename)) | 
					
						
							|  |  |  |         except Exception as E: | 
					
						
							|  |  |  |             print(E) | 
					
						
							|  |  |  |             raise ValueError( | 
					
						
							|  |  |  |                 "The results of {} and the results of {} are inconsistent!". | 
					
						
							|  |  |  |                 format(filename, gt_filename)) |