Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-11-01 18:29:32 +00:00
refactor: Refactor StatisticalEvaluator (#6999)
* Refactor StatisticalEvaluator
* Update StatisticalEvaluator
* Rename StatisticalMetric.from_string to from_str and change internal logic
* Fix tests

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
This commit is contained in: parent c82f787b41, commit 2b8a606cb8
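In short: labels move from the constructor to run(), the nested StatisticalEvaluator.Metric enum becomes a top-level StatisticalMetric, and the metric can now be passed as a plain string. A minimal usage sketch of the new API follows; the label and prediction strings are illustrative, not taken from the repository.

from haystack.components.eval import StatisticalEvaluator, StatisticalMetric

# The metric is the only init parameter now; a string such as "f1" works too.
evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)

# Labels are supplied per call instead of being fixed at construction time.
result = evaluator.run(labels=["Berlin", "Paris"], predictions=["Berlin", "Rome"])
print(result["result"])  # a float between 0.0 and 1.0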
@@ -1,4 +1,4 @@
 from .sas_evaluator import SASEvaluator
-from .statistical_evaluator import StatisticalEvaluator
+from .statistical_evaluator import StatisticalEvaluator, StatisticalMetric
 
-__all__ = ["SASEvaluator", "StatisticalEvaluator"]
+__all__ = ["SASEvaluator", "StatisticalEvaluator", "StatisticalMetric"]
@@ -1,6 +1,6 @@
 import collections
 from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Union
 
 from numpy import array as np_array
 from numpy import mean as np_mean
@@ -8,7 +8,22 @@ from numpy import mean as np_mean
 from haystack import default_from_dict, default_to_dict
 from haystack.core.component import component
 
-from .preprocess import _preprocess_text
+
+class StatisticalMetric(Enum):
+    """
+    Metrics supported by the StatisticalEvaluator.
+    """
+
+    F1 = "f1"
+    EM = "exact_match"
+
+    @classmethod
+    def from_str(cls, metric: str) -> "StatisticalMetric":
+        map = {e.value: e for e in StatisticalMetric}
+        metric_ = map.get(metric)
+        if metric_ is None:
+            raise ValueError(f"Unknown statistical metric '{metric}'")
+        return metric_
 
 
 @component
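The from_str helper added above maps a metric's string value back to the enum member and rejects unknown names. A short sketch of the expected behavior, inferred directly from the hunk:

from haystack.components.eval import StatisticalMetric

assert StatisticalMetric.from_str("f1") is StatisticalMetric.F1
assert StatisticalMetric.from_str("exact_match") is StatisticalMetric.EM

# Anything outside the enum's values raises ValueError.
try:
    StatisticalMetric.from_str("bleu")
except ValueError as err:
    print(err)  # Unknown statistical metric 'bleu'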
@@ -22,79 +37,44 @@ class StatisticalEvaluator:
     - Exact Match: Measures the proportion of cases where prediction is identical to the expected label.
     """
 
-    class Metric(Enum):
-        """
-        Supported metrics
-        """
-
-        F1 = "F1"
-        EM = "Exact Match"
-
-    def __init__(
-        self,
-        labels: List[str],
-        metric: Metric,
-        regexes_to_ignore: Optional[List[str]] = None,
-        ignore_case: bool = False,
-        ignore_punctuation: bool = False,
-        ignore_numbers: bool = False,
-    ):
+    def __init__(self, metric: Union[str, StatisticalMetric]):
         """
         Creates a new instance of StatisticalEvaluator.
 
-        :param labels: The list of expected answers.
         :param metric: Metric to use for evaluation in this component. Supported metrics are F1 and Exact Match.
-        :param regexes_to_ignore: A list of regular expressions. If provided, it removes substrings
-            matching these regular expressions from both predictions and labels before comparison. Defaults to None.
-        :param ignore_case: If True, performs case-insensitive comparison. Defaults to False.
-        :param ignore_punctuation: If True, removes punctuation from both predictions and labels before
-            comparison. Defaults to False.
-        :param ignore_numbers: If True, removes numerical digits from both predictions and labels
-            before comparison. Defaults to False.
         """
-        self._labels = labels
+        if isinstance(metric, str):
+            metric = StatisticalMetric.from_str(metric)
         self._metric = metric
-        self._regexes_to_ignore = regexes_to_ignore
-        self._ignore_case = ignore_case
-        self._ignore_punctuation = ignore_punctuation
-        self._ignore_numbers = ignore_numbers
 
-        self._metric_function = {
-            StatisticalEvaluator.Metric.F1: self._f1,
-            StatisticalEvaluator.Metric.EM: self._exact_match,
-        }[self._metric]
+        self._metric_function = {StatisticalMetric.F1: self._f1, StatisticalMetric.EM: self._exact_match}[self._metric]
 
     def to_dict(self) -> Dict[str, Any]:
-        return default_to_dict(
-            self,
-            labels=self._labels,
-            metric=self._metric.value,
-            regexes_to_ignore=self._regexes_to_ignore,
-            ignore_case=self._ignore_case,
-            ignore_punctuation=self._ignore_punctuation,
-            ignore_numbers=self._ignore_numbers,
-        )
+        return default_to_dict(self, metric=self._metric.value)
 
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "StatisticalEvaluator":
-        data["init_parameters"]["metric"] = StatisticalEvaluator.Metric(data["init_parameters"]["metric"])
+        data["init_parameters"]["metric"] = StatisticalMetric(data["init_parameters"]["metric"])
         return default_from_dict(cls, data)
 
     @component.output_types(result=float)
-    def run(self, predictions: List[str]) -> Dict[str, Any]:
-        if len(predictions) != len(self._labels):
-            raise ValueError("The number of predictions and labels must be the same.")
-
-        predictions = _preprocess_text(
-            predictions, self._regexes_to_ignore, self._ignore_case, self._ignore_punctuation, self._ignore_numbers
-        )
-        labels = _preprocess_text(
-            self._labels, self._regexes_to_ignore, self._ignore_case, self._ignore_punctuation, self._ignore_numbers
-        )
+    def run(self, labels: List[str], predictions: List[str]) -> Dict[str, Any]:
+        """
+        Run the StatisticalEvaluator to compute the metric between a list of predictions and a list of labels.
+        Both must be lists of strings of the same length.
+
+        :param predictions: List of predictions.
+        :param labels: List of labels against which the predictions are compared.
+        :returns: A dictionary with the following outputs:
+            * `result` - Calculated result of the chosen metric.
+        """
+        if len(labels) != len(predictions):
+            raise ValueError("The number of predictions and labels must be the same.")
 
         return {"result": self._metric_function(labels, predictions)}
 
-    def _f1(self, labels: List[str], predictions: List[str]):
+    @staticmethod
+    def _f1(labels: List[str], predictions: List[str]):
         """
         Measure word overlap between predictions and labels.
         """
@@ -120,7 +100,8 @@ class StatisticalEvaluator:
 
         return np_mean(scores)
 
-    def _exact_match(self, labels: List[str], predictions: List[str]) -> float:
+    @staticmethod
+    def _exact_match(labels: List[str], predictions: List[str]) -> float:
         """
         Measure the proportion of cases where a prediction is identical to the expected label.
         """
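The bodies of _f1 and _exact_match fall mostly outside these hunks. As a reference point only, here is a self-contained approximation of what a token-overlap F1 and an exact-match proportion compute; it reproduces the values asserted in the tests below, but it is not the repository's implementation.

import collections
from typing import List

from numpy import mean as np_mean


def token_f1(labels: List[str], predictions: List[str]) -> float:
    """Average per-pair F1 over whitespace tokens (an approximation of _f1)."""
    scores = []
    for label, prediction in zip(labels, predictions):
        label_tokens = label.split()
        pred_tokens = prediction.split()
        # Multiset intersection counts tokens shared between label and prediction.
        common = collections.Counter(label_tokens) & collections.Counter(pred_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            scores.append(0.0)
            continue
        precision = num_same / len(pred_tokens)
        recall = num_same / len(label_tokens)
        scores.append(2 * precision * recall / (precision + recall))
    return float(np_mean(scores)) if scores else 0.0


def exact_match(labels: List[str], predictions: List[str]) -> float:
    """Proportion of predictions identical to their labels (as _exact_match describes)."""
    if not predictions:
        return 0.0
    return sum(p == l for p, l in zip(predictions, labels)) / len(predictions)

Sanity check against the test suite: token_f1(["Source"], ["Open Source"]) gives 2/3, and exact_match with two of three identical pairs gives 2/3, matching the assertions below.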
@@ -1,33 +1,23 @@
 import pytest
 
-from haystack.components.eval import StatisticalEvaluator
+from haystack.components.eval import StatisticalEvaluator, StatisticalMetric
 
 
 class TestStatisticalEvaluator:
     def test_init_default(self):
-        labels = ["label1", "label2", "label3"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
-        assert evaluator._labels == labels
-        assert evaluator._metric == StatisticalEvaluator.Metric.F1
-        assert evaluator._regexes_to_ignore is None
-        assert evaluator._ignore_case is False
-        assert evaluator._ignore_punctuation is False
-        assert evaluator._ignore_numbers is False
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
+        assert evaluator._metric == StatisticalMetric.F1
+
+    def test_init_with_string(self):
+        evaluator = StatisticalEvaluator(metric="exact_match")
+        assert evaluator._metric == StatisticalMetric.EM
 
     def test_to_dict(self):
-        labels = ["label1", "label2", "label3"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
 
         expected_dict = {
             "type": "haystack.components.eval.statistical_evaluator.StatisticalEvaluator",
-            "init_parameters": {
-                "labels": labels,
-                "metric": "F1",
-                "regexes_to_ignore": None,
-                "ignore_case": False,
-                "ignore_punctuation": False,
-                "ignore_numbers": False,
-            },
+            "init_parameters": {"metric": "f1"},
         }
         assert evaluator.to_dict() == expected_dict
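As the updated test_to_dict and from_dict tests show, the serialized form now carries only the metric. A round-trip sketch based on the expected_dict in the hunk above:

from haystack.components.eval import StatisticalEvaluator, StatisticalMetric

evaluator = StatisticalEvaluator(metric="f1")
data = evaluator.to_dict()
# data == {
#     "type": "haystack.components.eval.statistical_evaluator.StatisticalEvaluator",
#     "init_parameters": {"metric": "f1"},
# }
restored = StatisticalEvaluator.from_dict(data)
assert restored._metric is StatisticalMetric.F1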
@@ -35,225 +25,99 @@ class TestStatisticalEvaluator:
         evaluator = StatisticalEvaluator.from_dict(
             {
                 "type": "haystack.components.eval.statistical_evaluator.StatisticalEvaluator",
-                "init_parameters": {
-                    "labels": ["label1", "label2", "label3"],
-                    "metric": "F1",
-                    "regexes_to_ignore": None,
-                    "ignore_case": False,
-                    "ignore_punctuation": False,
-                    "ignore_numbers": False,
-                },
+                "init_parameters": {"metric": "f1"},
             }
         )
 
-        assert evaluator._labels == ["label1", "label2", "label3"]
-        assert evaluator._metric == StatisticalEvaluator.Metric.F1
-        assert evaluator._regexes_to_ignore is None
-        assert evaluator._ignore_case is False
-        assert evaluator._ignore_punctuation is False
-        assert evaluator._ignore_numbers is False
+        assert evaluator._metric == StatisticalMetric.F1
 
 
 class TestStatisticalEvaluatorF1:
     def test_run_with_empty_inputs(self):
-        evaluator = StatisticalEvaluator(labels=[], metric=StatisticalEvaluator.Metric.F1)
-        result = evaluator.run(predictions=[])
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
+        result = evaluator.run(labels=[], predictions=[])
         assert len(result) == 1
         assert result["result"] == 0.0
 
     def test_run_with_different_lengths(self):
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
         labels = [
             "A construction budget of US $2.3 billion",
             "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
         ]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
-
         predictions = [
             "A construction budget of US $2.3 billion",
             "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
         with pytest.raises(ValueError):
-            evaluator.run(predictions)
+            evaluator.run(labels=labels, predictions=predictions)
 
     def test_run_with_matching_predictions(self):
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
         labels = ["OpenSource", "HaystackAI", "LLMs"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
         predictions = ["OpenSource", "HaystackAI", "LLMs"]
-        result = evaluator.run(predictions=predictions)
+        result = evaluator.run(labels=labels, predictions=predictions)
 
         assert len(result) == 1
         assert result["result"] == 1.0
 
     def test_run_with_single_prediction(self):
-        labels = ["Source"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
 
-        result = evaluator.run(predictions=["Open Source"])
+        result = evaluator.run(labels=["Source"], predictions=["Open Source"])
         assert len(result) == 1
        assert result["result"] == pytest.approx(2 / 3)
 
     def test_run_with_mismatched_predictions(self):
         labels = ["Source", "HaystackAI"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
         predictions = ["Open Source", "HaystackAI"]
-        result = evaluator.run(predictions=predictions)
+        result = evaluator.run(labels=labels, predictions=predictions)
         assert len(result) == 1
         assert result["result"] == pytest.approx(5 / 6)
 
-    def test_run_with_ignore_case(self):
-        labels = ["source", "HAYSTACKAI"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1, ignore_case=True)
-        predictions = ["Open Source", "HaystackAI"]
-        result = evaluator.run(predictions=predictions)
-        assert len(result) == 1
-        assert result["result"] == pytest.approx(5 / 6)
-
-    def test_run_with_ignore_punctuation(self):
-        labels = ["Source", "HaystackAI"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1, ignore_punctuation=True)
-        predictions = ["Open Source!", "Haystack.AI"]
-        result = evaluator.run(predictions=predictions)
-
-        assert result["result"] == pytest.approx(5 / 6)
-
-    def test_run_with_ignore_numbers(self):
-        labels = ["Source", "HaystackAI"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1, ignore_numbers=True)
-        predictions = ["Open Source123", "HaystackAI"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == pytest.approx(5 / 6)
-
-    def test_run_with_regex_to_ignore(self):
-        labels = ["Source", "HaystackAI"]
-        evaluator = StatisticalEvaluator(
-            labels=labels, metric=StatisticalEvaluator.Metric.F1, regexes_to_ignore=[r"\d+"]
-        )
-        predictions = ["Open123 Source", "HaystackAI"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == pytest.approx(5 / 6)
-
-    def test_run_with_multiple_regex_to_ignore(self):
-        labels = ["Source", "HaystackAI"]
-        evaluator = StatisticalEvaluator(
-            labels=labels, metric=StatisticalEvaluator.Metric.F1, regexes_to_ignore=[r"\d+", r"[^\w\s]"]
-        )
-        predictions = ["Open123! Source", "Haystack.AI"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == pytest.approx(5 / 6)
-
-    def test_run_with_multiple_ignore_parameters(self):
-        labels = ["Source", "HaystackAI"]
-        evaluator = StatisticalEvaluator(
-            labels=labels,
-            metric=StatisticalEvaluator.Metric.F1,
-            ignore_numbers=True,
-            ignore_punctuation=True,
-            ignore_case=True,
-            regexes_to_ignore=[r"[^\w\s\d]+"],
-        )
-        predictions = ["Open%123. !$Source", "Haystack.AI##"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == pytest.approx(5 / 6)
-
 
 class TestStatisticalEvaluatorExactMatch:
     def test_run_with_empty_inputs(self):
-        evaluator = StatisticalEvaluator(labels=[], metric=StatisticalEvaluator.Metric.EM)
-        result = evaluator.run(predictions=[])
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
+        result = evaluator.run(predictions=[], labels=[])
         assert len(result) == 1
         assert result["result"] == 0.0
 
     def test_run_with_different_lengths(self):
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
         labels = [
             "A construction budget of US $2.3 billion",
             "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
         ]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM)
-
         predictions = [
             "A construction budget of US $2.3 billion",
             "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
         with pytest.raises(ValueError):
-            evaluator.run(predictions)
+            evaluator.run(labels=labels, predictions=predictions)
 
     def test_run_with_matching_predictions(self):
         labels = ["OpenSource", "HaystackAI", "LLMs"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM)
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
         predictions = ["OpenSource", "HaystackAI", "LLMs"]
-        result = evaluator.run(predictions=predictions)
+        result = evaluator.run(labels=labels, predictions=predictions)
 
         assert len(result) == 1
         assert result["result"] == 1.0
 
     def test_run_with_single_prediction(self):
-        labels = ["OpenSource"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM)
-
-        result = evaluator.run(predictions=["OpenSource"])
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
+        result = evaluator.run(labels=["OpenSource"], predictions=["OpenSource"])
         assert len(result) == 1
         assert result["result"] == 1.0
 
     def test_run_with_mismatched_predictions(self):
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
         labels = ["Source", "HaystackAI", "LLMs"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM)
         predictions = ["OpenSource", "HaystackAI", "LLMs"]
-        result = evaluator.run(predictions=predictions)
+        result = evaluator.run(labels=labels, predictions=predictions)
         assert len(result) == 1
         assert result["result"] == 2 / 3
-
-    def test_run_with_ignore_case(self):
-        labels = ["opensource", "HAYSTACKAI", "llMs"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM, ignore_case=True)
-        predictions = ["OpenSource", "HaystackAI", "LLMs"]
-        result = evaluator.run(predictions=predictions)
-        assert len(result) == 1
-        assert result["result"] == 1.0
-
-    def test_run_with_ignore_punctuation(self):
-        labels = ["OpenSource", "HaystackAI", "LLMs"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM, ignore_punctuation=True)
-        predictions = ["OpenSource!", "Haystack.AI", "LLMs,"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == 1.0
-
-    def test_run_with_ignore_numbers(self):
-        labels = ["OpenSource", "HaystackAI", "LLMs"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM, ignore_numbers=True)
-        predictions = ["OpenSource123", "HaystackAI", "LLMs456"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == 1.0
-
-    def test_run_with_regex_to_ignore(self):
-        labels = ["OpenSource", "HaystackAI", "LLMs"]
-        evaluator = StatisticalEvaluator(
-            labels=labels, metric=StatisticalEvaluator.Metric.EM, regexes_to_ignore=[r"\d+"]
-        )
-        predictions = ["Open123Source", "HaystackAI", "LLMs456"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == 1.0
-
-    def test_run_with_multiple_regex_to_ignore(self):
-        labels = ["OpenSource", "HaystackAI", "LLMs"]
-        evaluator = StatisticalEvaluator(
-            labels=labels, metric=StatisticalEvaluator.Metric.EM, regexes_to_ignore=[r"\d+", r"\W+"]
-        )
-        predictions = ["Open123!Source", "Haystack.AI", "LLMs456,"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == 1.0
-
-    def test_run_with_multiple_ignore_parameters(self):
-        labels = ["OpenSource", "HaystackAI", "LLMs"]
-        evaluator = StatisticalEvaluator(
-            labels=labels,
-            metric=StatisticalEvaluator.Metric.EM,
-            ignore_numbers=True,
-            ignore_punctuation=True,
-            ignore_case=True,
-            regexes_to_ignore=[r"[^\w\s\d]+"],
-        )
-        predictions = ["Open%123!$Source", "Haystack.AI##", "^^LLMs456,"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == 1.0