| 
									
										
										
										
											2021-06-04 02:46:45 +00:00
										 |  |  | # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | # you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | # You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #    http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | # distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | # See the License for the specific language governing permissions and | 
					
						
							|  |  |  | # limitations under the License. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import importlib | 
					
						
							|  |  |  | import copy | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from .rec_metric import RecMetric | 
					
						
							|  |  |  | from .det_metric import DetMetric | 
					
						
							|  |  |  | from .e2e_metric import E2EMetric | 
					
						
							|  |  |  | from .cls_metric import ClsMetric | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class DistillationMetric(object): | 
					
						
							|  |  |  |     def __init__(self, | 
					
						
							|  |  |  |                  key=None, | 
					
						
							| 
									
										
										
										
											2021-07-08 14:32:44 +00:00
										 |  |  |                  base_metric_name=None, | 
					
						
							|  |  |  |                  main_indicator=None, | 
					
						
							| 
									
										
										
										
											2021-06-04 02:46:45 +00:00
										 |  |  |                  **kwargs): | 
					
						
							|  |  |  |         self.main_indicator = main_indicator | 
					
						
							|  |  |  |         self.key = key | 
					
						
							|  |  |  |         self.main_indicator = main_indicator | 
					
						
							|  |  |  |         self.base_metric_name = base_metric_name | 
					
						
							|  |  |  |         self.kwargs = kwargs | 
					
						
							|  |  |  |         self.metrics = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _init_metrcis(self, preds): | 
					
						
							|  |  |  |         self.metrics = dict() | 
					
						
							|  |  |  |         mod = importlib.import_module(__name__) | 
					
						
							|  |  |  |         for key in preds: | 
					
						
							|  |  |  |             self.metrics[key] = getattr(mod, self.base_metric_name)( | 
					
						
							|  |  |  |                 main_indicator=self.main_indicator, **self.kwargs) | 
					
						
							|  |  |  |             self.metrics[key].reset() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-07-08 14:32:44 +00:00
										 |  |  |     def __call__(self, preds, batch, **kwargs): | 
					
						
							| 
									
										
										
										
											2021-06-04 02:46:45 +00:00
										 |  |  |         assert isinstance(preds, dict) | 
					
						
							|  |  |  |         if self.metrics is None: | 
					
						
							|  |  |  |             self._init_metrcis(preds) | 
					
						
							|  |  |  |         output = dict() | 
					
						
							|  |  |  |         for key in preds: | 
					
						
							| 
									
										
										
										
											2021-07-08 14:32:44 +00:00
										 |  |  |             self.metrics[key].__call__(preds[key], batch, **kwargs) | 
					
						
							| 
									
										
										
										
											2021-06-04 02:46:45 +00:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def get_metric(self): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         return metrics { | 
					
						
							|  |  |  |                  'acc': 0, | 
					
						
							|  |  |  |                  'norm_edit_dis': 0, | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         output = dict() | 
					
						
							|  |  |  |         for key in self.metrics: | 
					
						
							|  |  |  |             metric = self.metrics[key].get_metric() | 
					
						
							|  |  |  |             # main indicator | 
					
						
							|  |  |  |             if key == self.key: | 
					
						
							|  |  |  |                 output.update(metric) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 for sub_key in metric: | 
					
						
							|  |  |  |                     output["{}_{}".format(key, sub_key)] = metric[sub_key] | 
					
						
							|  |  |  |         return output | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def reset(self): | 
					
						
							|  |  |  |         for key in self.metrics: | 
					
						
							|  |  |  |             self.metrics[key].reset() |