mirror of
				https://github.com/PaddlePaddle/PaddleOCR.git
				synced 2025-10-31 09:49:30 +00:00 
			
		
		
		
	
		
			
	
	
		
			34 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			34 lines
		
	
	
		
			1.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. | ||
|  | # | ||
|  | # Licensed under the Apache License, Version 2.0 (the "License"); | ||
|  | # you may not use this file except in compliance with the License. | ||
|  | # You may obtain a copy of the License at | ||
|  | # | ||
|  | #    http://www.apache.org/licenses/LICENSE-2.0 | ||
|  | # | ||
|  | # Unless required by applicable law or agreed to in writing, software | ||
|  | # distributed under the License is distributed on an "AS IS" BASIS, | ||
|  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
|  | # See the License for the specific language governing permissions and | ||
|  | # limitations under the License. | ||
import numpy as np
import paddle
|  | 
 | ||
|  | 
 | ||
class ClsPostProcess(object):
    """Decode classifier outputs into ``(label, score)`` pairs.

    Args:
        label_list: Sequence mapping a class index to its label. When ``None``,
            the raw class index itself is reported as the label.
    """

    def __init__(self, label_list=None, **kwargs):
        super(ClsPostProcess, self).__init__()
        self.label_list = label_list

    @staticmethod
    def _to_numpy(x):
        """Return ``x`` as a numpy array, unwrapping a ``paddle.Tensor``."""
        # Check ndarray first so numpy inputs never need paddle at all.
        if isinstance(x, np.ndarray):
            return x
        if isinstance(x, paddle.Tensor):
            return x.numpy()
        return np.asarray(x)

    def __call__(self, preds, label=None, *args, **kwargs):
        """Decode a batch of predictions and, optionally, ground-truth labels.

        Args:
            preds: 2-D array-like of per-class scores, one row per sample
                (a ``paddle.Tensor``, numpy array, or nested sequence).
            label: Optional ground-truth class indices, one per sample. Any
                shape that flattens to one index per sample (e.g. ``(N,)``
                or ``(N, 1)``) is accepted.

        Returns:
            When ``label`` is ``None``: a list of ``(label, score)`` tuples
            with the arg-max class of each row and that row's score for it.
            Otherwise: the tuple ``(decode_out, label_out)``, where
            ``label_out`` is a list of ``(label, 1.0)`` tuples.
        """
        preds = self._to_numpy(preds)
        label_list = self.label_list
        if label_list is None:
            # No names configured: report the class index as the label.
            label_list = list(range(preds.shape[1]))
        pred_idxs = preds.argmax(axis=1)
        decode_out = [(label_list[idx], preds[i, idx])
                      for i, idx in enumerate(pred_idxs)]
        if label is None:
            return decode_out
        # Flatten so an (N, 1) label column yields scalar indices rather than
        # 1-element arrays, which cannot be used to index a Python list.
        label_idxs = self._to_numpy(label).reshape(-1)
        label = [(label_list[idx], 1.0) for idx in label_idxs]
        return decode_out, label