| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  | # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Licensed under the Apache License, Version 2.0 (the "License"); | 
					
						
							|  |  |  | # you may not use this file except in compliance with the License. | 
					
						
							|  |  |  | # You may obtain a copy of the License at | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | #     http://www.apache.org/licenses/LICENSE-2.0 | 
					
						
							|  |  |  | # | 
					
						
							|  |  |  | # Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  | # distributed under the License is distributed on an "AS IS" BASIS, | 
					
						
							|  |  |  | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | 
					
						
							|  |  |  | # See the License for the specific language governing permissions and | 
					
						
							|  |  |  | # limitations under the License. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from __future__ import absolute_import | 
					
						
							|  |  |  | from __future__ import division | 
					
						
							|  |  |  | from __future__ import print_function | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | import numpy as np | 
					
						
							|  |  |  | import paddle | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | __all__ = ['KIEMetric'] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class VQAReTokenMetric(object): | 
					
						
							|  |  |  |     def __init__(self, main_indicator='hmean', **kwargs): | 
					
						
							|  |  |  |         self.main_indicator = main_indicator | 
					
						
							|  |  |  |         self.reset() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __call__(self, preds, batch, **kwargs): | 
					
						
							|  |  |  |         pred_relations, relations, entities = preds | 
					
						
							|  |  |  |         self.pred_relations_list.extend(pred_relations) | 
					
						
							|  |  |  |         self.relations_list.extend(relations) | 
					
						
							|  |  |  |         self.entities_list.extend(entities) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_metric(self): | 
					
						
							|  |  |  |         gt_relations = [] | 
					
						
							|  |  |  |         for b in range(len(self.relations_list)): | 
					
						
							|  |  |  |             rel_sent = [] | 
					
						
							| 
									
										
										
										
											2022-09-20 22:13:27 +08:00
										 |  |  |             relation_list = self.relations_list[b] | 
					
						
							|  |  |  |             entitie_list = self.entities_list[b] | 
					
						
							|  |  |  |             head_len = relation_list[0, 0] | 
					
						
							|  |  |  |             if head_len > 0: | 
					
						
							|  |  |  |                 entitie_start_list = entitie_list[1:entitie_list[0, 0] + 1, 0] | 
					
						
							|  |  |  |                 entitie_end_list = entitie_list[1:entitie_list[0, 1] + 1, 1] | 
					
						
							|  |  |  |                 entitie_label_list = entitie_list[1:entitie_list[0, 2] + 1, 2] | 
					
						
							|  |  |  |                 for head, tail in zip(relation_list[1:head_len + 1, 0], | 
					
						
							|  |  |  |                                       relation_list[1:head_len + 1, 1]): | 
					
						
							| 
									
										
										
										
											2022-07-06 13:58:46 +08:00
										 |  |  |                     rel = {} | 
					
						
							|  |  |  |                     rel["head_id"] = head | 
					
						
							| 
									
										
										
										
											2022-09-20 22:13:27 +08:00
										 |  |  |                     rel["head"] = (entitie_start_list[head], | 
					
						
							|  |  |  |                                    entitie_end_list[head]) | 
					
						
							|  |  |  |                     rel["head_type"] = entitie_label_list[head] | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-07-06 13:58:46 +08:00
										 |  |  |                     rel["tail_id"] = tail | 
					
						
							| 
									
										
										
										
											2022-09-20 22:13:27 +08:00
										 |  |  |                     rel["tail"] = (entitie_start_list[tail], | 
					
						
							|  |  |  |                                    entitie_end_list[tail]) | 
					
						
							|  |  |  |                     rel["tail_type"] = entitie_label_list[tail] | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2022-07-06 13:58:46 +08:00
										 |  |  |                     rel["type"] = 1 | 
					
						
							|  |  |  |                     rel_sent.append(rel) | 
					
						
							| 
									
										
										
										
											2022-01-05 11:03:45 +00:00
										 |  |  |             gt_relations.append(rel_sent) | 
					
						
							|  |  |  |         re_metrics = self.re_score( | 
					
						
							|  |  |  |             self.pred_relations_list, gt_relations, mode="boundaries") | 
					
						
							|  |  |  |         metrics = { | 
					
						
							|  |  |  |             "precision": re_metrics["ALL"]["p"], | 
					
						
							|  |  |  |             "recall": re_metrics["ALL"]["r"], | 
					
						
							|  |  |  |             "hmean": re_metrics["ALL"]["f1"], | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |         self.reset() | 
					
						
							|  |  |  |         return metrics | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def reset(self): | 
					
						
							|  |  |  |         self.pred_relations_list = [] | 
					
						
							|  |  |  |         self.relations_list = [] | 
					
						
							|  |  |  |         self.entities_list = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def re_score(self, pred_relations, gt_relations, mode="strict"): | 
					
						
							|  |  |  |         """Evaluate RE predictions
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             pred_relations (list) :  list of list of predicted relations (several relations in each sentence) | 
					
						
							|  |  |  |             gt_relations (list) :    list of list of ground truth relations | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 rel = { "head": (start_idx (inclusive), end_idx (exclusive)), | 
					
						
							|  |  |  |                         "tail": (start_idx (inclusive), end_idx (exclusive)), | 
					
						
							|  |  |  |                         "head_type": ent_type, | 
					
						
							|  |  |  |                         "tail_type": ent_type, | 
					
						
							|  |  |  |                         "type": rel_type} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             vocab (Vocab) :         dataset vocabulary | 
					
						
							|  |  |  |             mode (str) :            in 'strict' or 'boundaries'"""
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         assert mode in ["strict", "boundaries"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         relation_types = [v for v in [0, 1] if not v == 0] | 
					
						
							|  |  |  |         scores = { | 
					
						
							|  |  |  |             rel: { | 
					
						
							|  |  |  |                 "tp": 0, | 
					
						
							|  |  |  |                 "fp": 0, | 
					
						
							|  |  |  |                 "fn": 0 | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |             for rel in relation_types + ["ALL"] | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Count GT relations and Predicted relations | 
					
						
							|  |  |  |         n_sents = len(gt_relations) | 
					
						
							|  |  |  |         n_rels = sum([len([rel for rel in sent]) for sent in gt_relations]) | 
					
						
							|  |  |  |         n_found = sum([len([rel for rel in sent]) for sent in pred_relations]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Count TP, FP and FN per type | 
					
						
							|  |  |  |         for pred_sent, gt_sent in zip(pred_relations, gt_relations): | 
					
						
							|  |  |  |             for rel_type in relation_types: | 
					
						
							|  |  |  |                 # strict mode takes argument types into account | 
					
						
							|  |  |  |                 if mode == "strict": | 
					
						
							|  |  |  |                     pred_rels = {(rel["head"], rel["head_type"], rel["tail"], | 
					
						
							|  |  |  |                                   rel["tail_type"]) | 
					
						
							|  |  |  |                                  for rel in pred_sent | 
					
						
							|  |  |  |                                  if rel["type"] == rel_type} | 
					
						
							|  |  |  |                     gt_rels = {(rel["head"], rel["head_type"], rel["tail"], | 
					
						
							|  |  |  |                                 rel["tail_type"]) | 
					
						
							|  |  |  |                                for rel in gt_sent if rel["type"] == rel_type} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # boundaries mode only takes argument spans into account | 
					
						
							|  |  |  |                 elif mode == "boundaries": | 
					
						
							|  |  |  |                     pred_rels = {(rel["head"], rel["tail"]) | 
					
						
							|  |  |  |                                  for rel in pred_sent | 
					
						
							|  |  |  |                                  if rel["type"] == rel_type} | 
					
						
							|  |  |  |                     gt_rels = {(rel["head"], rel["tail"]) | 
					
						
							|  |  |  |                                for rel in gt_sent if rel["type"] == rel_type} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 scores[rel_type]["tp"] += len(pred_rels & gt_rels) | 
					
						
							|  |  |  |                 scores[rel_type]["fp"] += len(pred_rels - gt_rels) | 
					
						
							|  |  |  |                 scores[rel_type]["fn"] += len(gt_rels - pred_rels) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Compute per entity Precision / Recall / F1 | 
					
						
							|  |  |  |         for rel_type in scores.keys(): | 
					
						
							|  |  |  |             if scores[rel_type]["tp"]: | 
					
						
							|  |  |  |                 scores[rel_type]["p"] = scores[rel_type]["tp"] / ( | 
					
						
							|  |  |  |                     scores[rel_type]["fp"] + scores[rel_type]["tp"]) | 
					
						
							|  |  |  |                 scores[rel_type]["r"] = scores[rel_type]["tp"] / ( | 
					
						
							|  |  |  |                     scores[rel_type]["fn"] + scores[rel_type]["tp"]) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0: | 
					
						
							|  |  |  |                 scores[rel_type]["f1"] = ( | 
					
						
							|  |  |  |                     2 * scores[rel_type]["p"] * scores[rel_type]["r"] / | 
					
						
							|  |  |  |                     (scores[rel_type]["p"] + scores[rel_type]["r"])) | 
					
						
							|  |  |  |             else: | 
					
						
							|  |  |  |                 scores[rel_type]["f1"] = 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Compute micro F1 Scores | 
					
						
							|  |  |  |         tp = sum([scores[rel_type]["tp"] for rel_type in relation_types]) | 
					
						
							|  |  |  |         fp = sum([scores[rel_type]["fp"] for rel_type in relation_types]) | 
					
						
							|  |  |  |         fn = sum([scores[rel_type]["fn"] for rel_type in relation_types]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         if tp: | 
					
						
							|  |  |  |             precision = tp / (tp + fp) | 
					
						
							|  |  |  |             recall = tp / (tp + fn) | 
					
						
							|  |  |  |             f1 = 2 * precision * recall / (precision + recall) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         else: | 
					
						
							|  |  |  |             precision, recall, f1 = 0, 0, 0 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         scores["ALL"]["p"] = precision | 
					
						
							|  |  |  |         scores["ALL"]["r"] = recall | 
					
						
							|  |  |  |         scores["ALL"]["f1"] = f1 | 
					
						
							|  |  |  |         scores["ALL"]["tp"] = tp | 
					
						
							|  |  |  |         scores["ALL"]["fp"] = fp | 
					
						
							|  |  |  |         scores["ALL"]["fn"] = fn | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Compute Macro F1 Scores | 
					
						
							|  |  |  |         scores["ALL"]["Macro_f1"] = np.mean( | 
					
						
							|  |  |  |             [scores[ent_type]["f1"] for ent_type in relation_types]) | 
					
						
							|  |  |  |         scores["ALL"]["Macro_p"] = np.mean( | 
					
						
							|  |  |  |             [scores[ent_type]["p"] for ent_type in relation_types]) | 
					
						
							|  |  |  |         scores["ALL"]["Macro_r"] = np.mean( | 
					
						
							|  |  |  |             [scores[ent_type]["r"] for ent_type in relation_types]) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return scores |