Mirror of https://github.com/deepset-ai/haystack.git (synced 2026-01-04 02:57:34 +00:00)
Add Recall Multi Hit and Single Hit metric (#7038)
This commit is contained in:
parent 5910b4adc9
commit 9215882779
@@ -1,4 +1,5 @@
 import collections
+import itertools
 from enum import Enum
 from typing import Any, Dict, List, Union
 
@@ -16,6 +17,8 @@ class StatisticalMetric(Enum):
 
     F1 = "f1"
     EM = "exact_match"
+    RECALL_SINGLE_HIT = "recall_single_hit"
+    RECALL_MULTI_HIT = "recall_multi_hit"
 
     @classmethod
     def from_str(cls, metric: str) -> "StatisticalMetric":
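The two new members extend the metrics a StatisticalEvaluator can be configured with. A minimal sketch of how they are selected by name; the import path is an assumption, since the module location is not shown in this diff:

from haystack.components.evaluators import StatisticalMetric  # assumed import path

# from_str maps the string values of the new members back to the enum,
# which is what lets the evaluator accept plain strings as well.
assert StatisticalMetric.from_str("recall_single_hit") == StatisticalMetric.RECALL_SINGLE_HIT
assert StatisticalMetric.from_str("recall_multi_hit") == StatisticalMetric.RECALL_MULTI_HIT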
@@ -47,7 +50,12 @@ class StatisticalEvaluator:
             metric = StatisticalMetric.from_str(metric)
         self._metric = metric
 
-        self._metric_function = {StatisticalMetric.F1: self._f1, StatisticalMetric.EM: self._exact_match}[self._metric]
+        self._metric_function = {
+            StatisticalMetric.F1: self._f1,
+            StatisticalMetric.EM: self._exact_match,
+            StatisticalMetric.RECALL_SINGLE_HIT: self._recall_single_hit,
+            StatisticalMetric.RECALL_MULTI_HIT: self._recall_multi_hit,
+        }[self._metric]
 
     def to_dict(self) -> Dict[str, Any]:
        return default_to_dict(self, metric=self._metric.value)
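With the two new entries in the dispatch dict, constructing the evaluator with either recall metric wires up the matching private method. A minimal sketch, using the same assumed import path as above:

from haystack.components.evaluators import StatisticalEvaluator  # assumed import path

# Strings are routed through StatisticalMetric.from_str (see the hunk above),
# so passing "recall_multi_hit" and passing the enum member are equivalent.
evaluator = StatisticalEvaluator(metric="recall_multi_hit")

# to_dict records the metric by its string value via default_to_dict, so the
# serialized init parameters contain "recall_multi_hit".
print(evaluator.to_dict())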
@@ -68,9 +76,6 @@ class StatisticalEvaluator:
         :returns: A dictionary with the following outputs:
             * `result` - Calculated result of the chosen metric.
         """
-        if len(labels) != len(predictions):
-            raise ValueError("The number of predictions and labels must be the same.")
-
         return {"result": self._metric_function(labels, predictions)}
 
     @staticmethod
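run stays unchanged apart from the dispatch; a sketch of calling it with one of the new metrics (the printed value follows from the single-hit logic added further down in this diff, and the import path is again an assumption):

from haystack.components.evaluators import StatisticalEvaluator, StatisticalMetric  # assumed path

evaluator = StatisticalEvaluator(metric=StatisticalMetric.RECALL_SINGLE_HIT)
output = evaluator.run(
    labels=["Eiffel Tower", "Louvre Museum"],
    predictions=["The Eiffel Tower is in Paris.", "The Statue of Liberty is in New York."],
)
print(output["result"])  # 0.5 -- only "Eiffel Tower" appears in a prediction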
@@ -78,6 +83,9 @@ class StatisticalEvaluator:
         """
         Measure word overlap between predictions and labels.
         """
+        if len(labels) != len(predictions):
+            raise ValueError("The number of predictions and labels must be the same.")
+
         if len(predictions) == 0:
             # We expect callers of this function already checked if predictions and labels are equal length
             return 0.0
@@ -105,8 +113,40 @@
         """
         Measure the proportion of cases where the prediction is identical to the expected label.
         """
+        if len(labels) != len(predictions):
+            raise ValueError("The number of predictions and labels must be the same.")
+
         if len(predictions) == 0:
             # We expect callers of this function already checked if predictions and labels are equal length
             return 0.0
         score_list = np_array(predictions) == np_array(labels)
         return np_mean(score_list)
+
+    @staticmethod
+    def _recall_single_hit(labels: List[str], predictions: List[str]) -> float:
+        """
+        Measures the proportion of labels that are present in at least one prediction.
+        If the same label is found in multiple predictions, it is only counted once.
+        """
+        if len(labels) == 0:
+            return 0.0
+
+        # In Recall Single Hit we only consider if a label is present in at least one prediction.
+        # No need to count multiple occurrences of the same label in different predictions
+        retrieved_labels = {l for l, p in itertools.product(labels, predictions) if l in p}
+        return len(retrieved_labels) / len(labels)
+
+    @staticmethod
+    def _recall_multi_hit(labels: List[str], predictions: List[str]) -> float:
+        """
+        Measures how many times each label is present across all predictions (multiple matches are counted multiple times).
+        """
+        if len(labels) == 0:
+            return 0.0
+
+        correct_retrievals = 0
+        for label, prediction in itertools.product(labels, predictions):
+            if label in prediction:
+                correct_retrievals += 1
+
+        return correct_retrievals / len(labels)
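To make the difference between the two new metrics concrete, here is a small worked example that reimplements just the core expressions from the two methods above, using the same data as the tests below:

import itertools

labels = ["Eiffel Tower", "Louvre Museum", "Colosseum", "Trajan's Column"]
predictions = [
    "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
    "The Eiffel Tower max height is 330 meters.",
    "Louvre Museum is the world's largest art museum and a historic monument in Paris, France.",
    "The Leaning Tower of Pisa is the campanile, or freestanding bell tower, of Pisa Cathedral.",
]

# Single hit: a label counts once no matter how many predictions contain it.
# "Eiffel Tower" and "Louvre Museum" are found, so 2 / 4 = 0.5.
single_hit = len({l for l, p in itertools.product(labels, predictions) if l in p}) / len(labels)

# Multi hit: every (label, prediction) match counts. "Eiffel Tower" is matched by two
# predictions and "Louvre Museum" by one, so 3 / 4 = 0.75.
multi_hit = sum(1 for l, p in itertools.product(labels, predictions) if l in p) / len(labels)

print(single_hit, multi_hit)  # 0.5 0.75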
@@ -121,3 +121,71 @@ class TestStatisticalEvaluatorExactMatch:
         result = evaluator.run(labels=labels, predictions=predictions)
         assert len(result) == 1
         assert result["result"] == 2 / 3
+
+
+class TestStatisticalEvaluatorRecallSingleHit:
+    def test_run(self):
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.RECALL_SINGLE_HIT)
+        labels = ["Eiffel Tower", "Louvre Museum", "Colosseum", "Trajan's Column"]
+        predictions = [
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Eiffel Tower max height is 330 meters.",
+            "Louvre Museum is the world's largest art museum and a historic monument in Paris, France.",
+            "The Leaning Tower of Pisa is the campanile, or freestanding bell tower, of Pisa Cathedral.",
+        ]
+        result = evaluator.run(labels=labels, predictions=predictions)
+        assert len(result) == 1
+        assert result["result"] == 2 / 4
+
+    def test_run_with_empty_labels(self):
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.RECALL_SINGLE_HIT)
+        predictions = [
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Eiffel Tower max height is 330 meters.",
+            "Louvre Museum is the world's largest art museum and a historic monument in Paris, France.",
+            "The Leaning Tower of Pisa is the campanile, or freestanding bell tower, of Pisa Cathedral.",
+        ]
+        result = evaluator.run(labels=[], predictions=predictions)
+        assert len(result) == 1
+        assert result["result"] == 0.0
+
+    def test_run_with_empty_predictions(self):
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.RECALL_SINGLE_HIT)
+        labels = ["Eiffel Tower", "Louvre Museum", "Colosseum", "Trajan's Column"]
+        result = evaluator.run(labels=labels, predictions=[])
+        assert len(result) == 1
+        assert result["result"] == 0.0
+
+
+class TestStatisticalEvaluatorRecallMultiHit:
+    def test_run(self):
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.RECALL_MULTI_HIT)
+        labels = ["Eiffel Tower", "Louvre Museum", "Colosseum", "Trajan's Column"]
+        predictions = [
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Eiffel Tower max height is 330 meters.",
+            "Louvre Museum is the world's largest art museum and a historic monument in Paris, France.",
+            "The Leaning Tower of Pisa is the campanile, or freestanding bell tower, of Pisa Cathedral.",
+        ]
+        result = evaluator.run(labels=labels, predictions=predictions)
+        assert len(result) == 1
+        assert result["result"] == 0.75
+
+    def test_run_with_empty_labels(self):
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.RECALL_MULTI_HIT)
+        predictions = [
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Eiffel Tower max height is 330 meters.",
+            "Louvre Museum is the world's largest art museum and a historic monument in Paris, France.",
+            "The Leaning Tower of Pisa is the campanile, or freestanding bell tower, of Pisa Cathedral.",
+        ]
+        result = evaluator.run(labels=[], predictions=predictions)
+        assert len(result) == 1
+        assert result["result"] == 0.0
+
+    def test_run_with_empty_predictions(self):
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.RECALL_MULTI_HIT)
+        labels = ["Eiffel Tower", "Louvre Museum", "Colosseum", "Trajan's Column"]
+        result = evaluator.run(labels=labels, predictions=[])
+        assert len(result) == 1
+        assert result["result"] == 0.0