Add Recall Multi Hit and Single Hit metric (#7038)

This commit is contained in:
Silvano Cerza 2024-02-19 18:00:39 +01:00 committed by GitHub
parent 5910b4adc9
commit 9215882779
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 112 additions and 4 deletions

View File

@ -1,4 +1,5 @@
import collections
import itertools
from enum import Enum
from typing import Any, Dict, List, Union
@ -16,6 +17,8 @@ class StatisticalMetric(Enum):
F1 = "f1"
EM = "exact_match"
RECALL_SINGLE_HIT = "recall_single_hit"
RECALL_MULTI_HIT = "recall_multi_hit"
@classmethod
def from_str(cls, metric: str) -> "StatisticalMetric":
@ -47,7 +50,12 @@ class StatisticalEvaluator:
metric = StatisticalMetric.from_str(metric)
self._metric = metric
self._metric_function = {StatisticalMetric.F1: self._f1, StatisticalMetric.EM: self._exact_match}[self._metric]
self._metric_function = {
StatisticalMetric.F1: self._f1,
StatisticalMetric.EM: self._exact_match,
StatisticalMetric.RECALL_SINGLE_HIT: self._recall_single_hit,
StatisticalMetric.RECALL_MULTI_HIT: self._recall_multi_hit,
}[self._metric]
def to_dict(self) -> Dict[str, Any]:
return default_to_dict(self, metric=self._metric.value)
@ -68,9 +76,6 @@ class StatisticalEvaluator:
:returns: A dictionary with the following outputs:
* `result` - Calculated result of the chosen metric.
"""
if len(labels) != len(predictions):
raise ValueError("The number of predictions and labels must be the same.")
return {"result": self._metric_function(labels, predictions)}
@staticmethod
@ -78,6 +83,9 @@ class StatisticalEvaluator:
"""
Measure word overlap between predictions and labels.
"""
if len(labels) != len(predictions):
raise ValueError("The number of predictions and labels must be the same.")
if len(predictions) == 0:
# We expect callers of this function already checked if predictions and labels are equal length
return 0.0
@ -105,8 +113,40 @@ class StatisticalEvaluator:
"""
Measure the proportion of cases where predictiond is identical to the the expected label.
"""
if len(labels) != len(predictions):
raise ValueError("The number of predictions and labels must be the same.")
if len(predictions) == 0:
# We expect callers of this function already checked if predictions and labels are equal length
return 0.0
score_list = np_array(predictions) == np_array(labels)
return np_mean(score_list)
@staticmethod
def _recall_single_hit(labels: List[str], predictions: List[str]) -> float:
"""
Measures how many times a label is present in at least one prediction.
If the same label is found in multiple predictions it is only counted once.
"""
if len(labels) == 0:
return 0.0
# In Recall Single Hit we only consider if a label is present in at least one prediction.
# No need to count multiple occurrences of the same label in different predictions
retrieved_labels = {l for l, p in itertools.product(labels, predictions) if l in p}
return len(retrieved_labels) / len(labels)
@staticmethod
def _recall_multi_hit(labels: List[str], predictions: List[str]) -> float:
"""
Measures how many times a label is present in at least one or more predictions.
"""
if len(labels) == 0:
return 0.0
correct_retrievals = 0
for label, prediction in itertools.product(labels, predictions):
if label in prediction:
correct_retrievals += 1
return correct_retrievals / len(labels)

View File

@ -121,3 +121,71 @@ class TestStatisticalEvaluatorExactMatch:
result = evaluator.run(labels=labels, predictions=predictions)
assert len(result) == 1
assert result["result"] == 2 / 3
class TestStatisticalEvaluatorRecallSingleHit:
def test_run(self):
evaluator = StatisticalEvaluator(metric=StatisticalMetric.RECALL_SINGLE_HIT)
labels = ["Eiffel Tower", "Louvre Museum", "Colosseum", "Trajan's Column"]
predictions = [
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Eiffel Tower max height is 330 meters.",
"Louvre Museum is the world's largest art museum and a historic monument in Paris, France.",
"The Leaning Tower of Pisa is the campanile, or freestanding bell tower, of Pisa Cathedral.",
]
result = evaluator.run(labels=labels, predictions=predictions)
assert len(result) == 1
assert result["result"] == 2 / 4
def test_run_with_empty_labels(self):
evaluator = StatisticalEvaluator(metric=StatisticalMetric.RECALL_SINGLE_HIT)
predictions = [
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Eiffel Tower max height is 330 meters.",
"Louvre Museum is the world's largest art museum and a historic monument in Paris, France.",
"The Leaning Tower of Pisa is the campanile, or freestanding bell tower, of Pisa Cathedral.",
]
result = evaluator.run(labels=[], predictions=predictions)
assert len(result) == 1
assert result["result"] == 0.0
def test_run_with_empty_predictions(self):
evaluator = StatisticalEvaluator(metric=StatisticalMetric.RECALL_SINGLE_HIT)
labels = ["Eiffel Tower", "Louvre Museum", "Colosseum", "Trajan's Column"]
result = evaluator.run(labels=labels, predictions=[])
assert len(result) == 1
assert result["result"] == 0.0
class TestStatisticalEvaluatorRecallMultiHit:
def test_run(self):
evaluator = StatisticalEvaluator(metric=StatisticalMetric.RECALL_MULTI_HIT)
labels = ["Eiffel Tower", "Louvre Museum", "Colosseum", "Trajan's Column"]
predictions = [
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Eiffel Tower max height is 330 meters.",
"Louvre Museum is the world's largest art museum and a historic monument in Paris, France.",
"The Leaning Tower of Pisa is the campanile, or freestanding bell tower, of Pisa Cathedral.",
]
result = evaluator.run(labels=labels, predictions=predictions)
assert len(result) == 1
assert result["result"] == 0.75
def test_run_with_empty_labels(self):
evaluator = StatisticalEvaluator(metric=StatisticalMetric.RECALL_MULTI_HIT)
predictions = [
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Eiffel Tower max height is 330 meters.",
"Louvre Museum is the world's largest art museum and a historic monument in Paris, France.",
"The Leaning Tower of Pisa is the campanile, or freestanding bell tower, of Pisa Cathedral.",
]
result = evaluator.run(labels=[], predictions=predictions)
assert len(result) == 1
assert result["result"] == 0.0
def test_run_with_empty_predictions(self):
evaluator = StatisticalEvaluator(metric=StatisticalMetric.RECALL_MULTI_HIT)
labels = ["Eiffel Tower", "Louvre Museum", "Colosseum", "Trajan's Column"]
result = evaluator.run(labels=labels, predictions=[])
assert len(result) == 1
assert result["result"] == 0.0