mirror of
				https://github.com/PaddlePaddle/PaddleOCR.git
				synced 2025-10-31 01:39:11 +00:00 
			
		
		
		
	fix infer
This commit is contained in:
		
							parent
							
								
									50020a8ef9
								
							
						
					
					
						commit
						c342b7a013
					
				| @ -47,9 +47,9 @@ def read_class_list(filepath): | |||||||
|     return dict |     return dict | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def draw_kie_result(batch, node, idx_to_cls): | def draw_kie_result(batch, node, idx_to_cls, count): | ||||||
|     img = batch[-2] |     img = batch[6].copy() | ||||||
|     boxes = batch[-1] |     boxes = batch[7] | ||||||
|     h, w = img.shape[:2] |     h, w = img.shape[:2] | ||||||
|     pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255 |     pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255 | ||||||
|     max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1) |     max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1) | ||||||
| @ -77,11 +77,15 @@ def draw_kie_result(batch, node, idx_to_cls): | |||||||
|         text = pred_label + '(' + pred_score + ')' |         text = pred_label + '(' + pred_score + ')' | ||||||
|         cv2.putText(pred_img, text, (x_min * 2, y_min), |         cv2.putText(pred_img, text, (x_min * 2, y_min), | ||||||
|                     cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1) |                     cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1) | ||||||
| 
 |  | ||||||
|     vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255 |     vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255 | ||||||
|     vis_img[:, :w] = img |     vis_img[:, :w] = img | ||||||
|     vis_img[:, w:] = pred_img |     vis_img[:, w:] = pred_img | ||||||
|     return vis_img |     save_kie_path = os.path.dirname(config['Global']['save_res_path']) + "/kie_results/" | ||||||
|  |     if not os.path.exists(save_kie_path): | ||||||
|  |         os.makedirs(save_kie_path) | ||||||
|  |     save_path = os.path.join(save_kie_path, str(count) + ".png") | ||||||
|  |     cv2.imwrite(save_path, vis_img) | ||||||
|  |     logger.info("The Kie Image saved in {}".format(save_path)) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def main(): | def main(): | ||||||
| @ -89,7 +93,6 @@ def main(): | |||||||
| 
 | 
 | ||||||
|     # build model |     # build model | ||||||
|     model = build_model(config['Architecture']) |     model = build_model(config['Architecture']) | ||||||
| 
 |  | ||||||
|     init_model(config, model, logger) |     init_model(config, model, logger) | ||||||
| 
 | 
 | ||||||
|     # create data ops |     # create data ops | ||||||
| @ -97,6 +100,8 @@ def main(): | |||||||
|     for op in config['Eval']['dataset']['transforms']: |     for op in config['Eval']['dataset']['transforms']: | ||||||
|         transforms.append(op) |         transforms.append(op) | ||||||
| 
 | 
 | ||||||
|  |     data_dir = config['Eval']['dataset']['data_dir'] | ||||||
|  | 
 | ||||||
|     ops = create_operators(transforms, global_config) |     ops = create_operators(transforms, global_config) | ||||||
| 
 | 
 | ||||||
|     save_res_path = config['Global']['save_res_path'] |     save_res_path = config['Global']['save_res_path'] | ||||||
| @ -109,11 +114,10 @@ def main(): | |||||||
|     with open(save_res_path, "wb") as fout: |     with open(save_res_path, "wb") as fout: | ||||||
|         with open(config['Global']['infer_img'], "rb") as f: |         with open(config['Global']['infer_img'], "rb") as f: | ||||||
|             lines = f.readlines() |             lines = f.readlines() | ||||||
|             for data_line in lines: |             for index, data_line in enumerate(lines): | ||||||
|                 data_line = data_line.decode('utf-8') |                 data_line = data_line.decode('utf-8') | ||||||
|                 substr = data_line.strip("\n").split("\t") |                 substr = data_line.strip("\n").split("\t") | ||||||
|                 img_path, label = "/Users/hongyongjie/project/PaddleOCR/train_data/wildreceipt/" + substr[ |                 img_path, label = data_dir + "/" + substr[0], substr[1] | ||||||
|                     0], substr[1] |  | ||||||
|                 data = {'img_path': img_path, 'label': label} |                 data = {'img_path': img_path, 'label': label} | ||||||
|                 with open(data['img_path'], 'rb') as f: |                 with open(data['img_path'], 'rb') as f: | ||||||
|                     img = f.read() |                     img = f.read() | ||||||
| @ -126,9 +130,7 @@ def main(): | |||||||
|                             batch[i], axis=0)) |                             batch[i], axis=0)) | ||||||
|                 node, edge = model(batch_pred) |                 node, edge = model(batch_pred) | ||||||
|                 node = F.softmax(node, -1) |                 node = F.softmax(node, -1) | ||||||
|                 img = draw_kie_result(batch, node, idx_to_cls) |                 draw_kie_result(batch, node, idx_to_cls, index) | ||||||
|                 cv2.imwrite('1.png', img) |  | ||||||
|                 exit() |  | ||||||
|     logger.info("success!") |     logger.info("success!") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 LDOUBLEV
						LDOUBLEV