refactor: Refactor StatisticalEvaluator (#6999)

* Refactor StatisticalEvaluator

* Update StatisticalEvaluator

* Rename StatisticalMetric.from_string to from_str and change internal logic

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* Fix tests

---------

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
Silvano Cerza 2024-02-15 16:47:35 +01:00 committed by GitHub
parent c82f787b41
commit 2b8a606cb8
3 changed files with 71 additions and 226 deletions


@@ -1,4 +1,4 @@
from .sas_evaluator import SASEvaluator
from .statistical_evaluator import StatisticalEvaluator
from .statistical_evaluator import StatisticalEvaluator, StatisticalMetric
__all__ = ["SASEvaluator", "StatisticalEvaluator"]
__all__ = ["SASEvaluator", "StatisticalEvaluator", "StatisticalMetric"]
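With StatisticalMetric exported alongside the evaluator, callers can pick the metric either as an enum member or as its string value. A minimal construction sketch, assuming the haystack.components.eval import path used by the tests in this commit:

from haystack.components.eval import StatisticalEvaluator, StatisticalMetric

# Both forms are equivalent; string values are resolved through StatisticalMetric.from_str.
f1_evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
em_evaluator = StatisticalEvaluator(metric="exact_match")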


@@ -1,6 +1,6 @@
import collections
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Union
from numpy import array as np_array
from numpy import mean as np_mean
@@ -8,7 +8,22 @@ from numpy import mean as np_mean
from haystack import default_from_dict, default_to_dict
from haystack.core.component import component
from .preprocess import _preprocess_text
class StatisticalMetric(Enum):
"""
Metrics supported by the StatisticalEvaluator.
"""
F1 = "f1"
EM = "exact_match"
@classmethod
def from_str(cls, metric: str) -> "StatisticalMetric":
map = {e.value: e for e in StatisticalMetric}
metric_ = map.get(metric)
if metric_ is None:
raise ValueError(f"Unknown statistical metric '{metric}'")
return metric_
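from_str maps the lowercase string values onto the enum members and rejects anything else with a ValueError. A short sketch of the expected behaviour, as I read the method above:

from haystack.components.eval import StatisticalMetric

assert StatisticalMetric.from_str("f1") is StatisticalMetric.F1
assert StatisticalMetric.from_str("exact_match") is StatisticalMetric.EM

# Unknown names fail loudly instead of falling back to a default.
try:
    StatisticalMetric.from_str("bleu")
except ValueError as error:
    print(error)  # Unknown statistical metric 'bleu'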
@component
@@ -22,79 +37,44 @@ class StatisticalEvaluator:
- Exact Match: Measures the proportion of cases where prediction is identical to the expected label.
"""
class Metric(Enum):
"""
Supported metrics
"""
F1 = "F1"
EM = "Exact Match"
def __init__(
self,
labels: List[str],
metric: Metric,
regexes_to_ignore: Optional[List[str]] = None,
ignore_case: bool = False,
ignore_punctuation: bool = False,
ignore_numbers: bool = False,
):
def __init__(self, metric: Union[str, StatisticalMetric]):
"""
Creates a new instance of StatisticalEvaluator.
:param labels: The list of expected answers.
:param metric: Metric to use for evaluation in this component. Supported metrics are F1 and Exact Match.
:param regexes_to_ignore: A list of regular expressions. If provided, it removes substrings
matching these regular expressions from both predictions and labels before comparison. Defaults to None.
:param ignore_case: If True, performs case-insensitive comparison. Defaults to False.
:param ignore_punctuation: If True, removes punctuation from both predictions and labels before
comparison. Defaults to False.
:param ignore_numbers: If True, removes numerical digits from both predictions and labels
before comparison. Defaults to False.
"""
self._labels = labels
if isinstance(metric, str):
metric = StatisticalMetric.from_str(metric)
self._metric = metric
self._regexes_to_ignore = regexes_to_ignore
self._ignore_case = ignore_case
self._ignore_punctuation = ignore_punctuation
self._ignore_numbers = ignore_numbers
self._metric_function = {
StatisticalEvaluator.Metric.F1: self._f1,
StatisticalEvaluator.Metric.EM: self._exact_match,
}[self._metric]
self._metric_function = {StatisticalMetric.F1: self._f1, StatisticalMetric.EM: self._exact_match}[self._metric]
def to_dict(self) -> Dict[str, Any]:
return default_to_dict(
self,
labels=self._labels,
metric=self._metric.value,
regexes_to_ignore=self._regexes_to_ignore,
ignore_case=self._ignore_case,
ignore_punctuation=self._ignore_punctuation,
ignore_numbers=self._ignore_numbers,
)
return default_to_dict(self, metric=self._metric.value)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "StatisticalEvaluator":
data["init_parameters"]["metric"] = StatisticalEvaluator.Metric(data["init_parameters"]["metric"])
data["init_parameters"]["metric"] = StatisticalMetric(data["init_parameters"]["metric"])
return default_from_dict(cls, data)
@component.output_types(result=float)
def run(self, predictions: List[str]) -> Dict[str, Any]:
if len(predictions) != len(self._labels):
raise ValueError("The number of predictions and labels must be the same.")
def run(self, labels: List[str], predictions: List[str]) -> Dict[str, Any]:
"""
Run the StatisticalEvaluator to compute the metric between a list of predictions and a list of labels.
Both must be lists of strings of the same length.
predictions = _preprocess_text(
predictions, self._regexes_to_ignore, self._ignore_case, self._ignore_punctuation, self._ignore_numbers
)
labels = _preprocess_text(
self._labels, self._regexes_to_ignore, self._ignore_case, self._ignore_punctuation, self._ignore_numbers
)
:param predictions: List of predictions.
:param labels: List of labels against which the predictions are compared.
:returns: A dictionary with the following outputs:
* `result` - Calculated result of the chosen metric.
"""
if len(labels) != len(predictions):
raise ValueError("The number of predictions and labels must be the same.")
return {"result": self._metric_function(labels, predictions)}
def _f1(self, labels: List[str], predictions: List[str]):
@staticmethod
def _f1(labels: List[str], predictions: List[str]):
"""
Measure word overlap between predictions and labels.
"""
@@ -120,7 +100,8 @@ class StatisticalEvaluator:
return np_mean(scores)
def _exact_match(self, labels: List[str], predictions: List[str]) -> float:
@staticmethod
def _exact_match(labels: List[str], predictions: List[str]) -> float:
"""
Measure the proportion of cases where the prediction is identical to the expected label.
"""


@@ -1,33 +1,23 @@
import pytest
from haystack.components.eval import StatisticalEvaluator
from haystack.components.eval import StatisticalEvaluator, StatisticalMetric
class TestStatisticalEvaluator:
def test_init_default(self):
labels = ["label1", "label2", "label3"]
evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
assert evaluator._labels == labels
assert evaluator._metric == StatisticalEvaluator.Metric.F1
assert evaluator._regexes_to_ignore is None
assert evaluator._ignore_case is False
assert evaluator._ignore_punctuation is False
assert evaluator._ignore_numbers is False
evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
assert evaluator._metric == StatisticalMetric.F1
def test_init_with_string(self):
evaluator = StatisticalEvaluator(metric="exact_match")
assert evaluator._metric == StatisticalMetric.EM
def test_to_dict(self):
labels = ["label1", "label2", "label3"]
evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
expected_dict = {
"type": "haystack.components.eval.statistical_evaluator.StatisticalEvaluator",
"init_parameters": {
"labels": labels,
"metric": "F1",
"regexes_to_ignore": None,
"ignore_case": False,
"ignore_punctuation": False,
"ignore_numbers": False,
},
"init_parameters": {"metric": "f1"},
}
assert evaluator.to_dict() == expected_dict
@@ -35,225 +25,99 @@ class TestStatisticalEvaluator:
evaluator = StatisticalEvaluator.from_dict(
{
"type": "haystack.components.eval.statistical_evaluator.StatisticalEvaluator",
"init_parameters": {
"labels": ["label1", "label2", "label3"],
"metric": "F1",
"regexes_to_ignore": None,
"ignore_case": False,
"ignore_punctuation": False,
"ignore_numbers": False,
},
"init_parameters": {"metric": "f1"},
}
)
assert evaluator._labels == ["label1", "label2", "label3"]
assert evaluator._metric == StatisticalEvaluator.Metric.F1
assert evaluator._regexes_to_ignore is None
assert evaluator._ignore_case is False
assert evaluator._ignore_punctuation is False
assert evaluator._ignore_numbers is False
assert evaluator._metric == StatisticalMetric.F1
class TestStatisticalEvaluatorF1:
def test_run_with_empty_inputs(self):
evaluator = StatisticalEvaluator(labels=[], metric=StatisticalEvaluator.Metric.F1)
result = evaluator.run(predictions=[])
evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
result = evaluator.run(labels=[], predictions=[])
assert len(result) == 1
assert result["result"] == 0.0
def test_run_with_different_lengths(self):
evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
labels = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
]
evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
with pytest.raises(ValueError):
evaluator.run(predictions)
evaluator.run(labels=labels, predictions=predictions)
def test_run_with_matching_predictions(self):
evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
labels = ["OpenSource", "HaystackAI", "LLMs"]
evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
predictions = ["OpenSource", "HaystackAI", "LLMs"]
result = evaluator.run(predictions=predictions)
result = evaluator.run(labels=labels, predictions=predictions)
assert len(result) == 1
assert result["result"] == 1.0
def test_run_with_single_prediction(self):
labels = ["Source"]
evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
result = evaluator.run(predictions=["Open Source"])
result = evaluator.run(labels=["Source"], predictions=["Open Source"])
assert len(result) == 1
assert result["result"] == pytest.approx(2 / 3)
def test_run_with_mismatched_predictions(self):
labels = ["Source", "HaystackAI"]
evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
predictions = ["Open Source", "HaystackAI"]
result = evaluator.run(predictions=predictions)
result = evaluator.run(labels=labels, predictions=predictions)
assert len(result) == 1
assert result["result"] == pytest.approx(5 / 6)
def test_run_with_ignore_case(self):
labels = ["source", "HAYSTACKAI"]
evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1, ignore_case=True)
predictions = ["Open Source", "HaystackAI"]
result = evaluator.run(predictions=predictions)
assert len(result) == 1
assert result["result"] == pytest.approx(5 / 6)
def test_run_with_ignore_punctuation(self):
labels = ["Source", "HaystackAI"]
evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1, ignore_punctuation=True)
predictions = ["Open Source!", "Haystack.AI"]
result = evaluator.run(predictions=predictions)
assert result["result"] == pytest.approx(5 / 6)
def test_run_with_ignore_numbers(self):
labels = ["Source", "HaystackAI"]
evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1, ignore_numbers=True)
predictions = ["Open Source123", "HaystackAI"]
result = evaluator.run(predictions=predictions)
assert result["result"] == pytest.approx(5 / 6)
def test_run_with_regex_to_ignore(self):
labels = ["Source", "HaystackAI"]
evaluator = StatisticalEvaluator(
labels=labels, metric=StatisticalEvaluator.Metric.F1, regexes_to_ignore=[r"\d+"]
)
predictions = ["Open123 Source", "HaystackAI"]
result = evaluator.run(predictions=predictions)
assert result["result"] == pytest.approx(5 / 6)
def test_run_with_multiple_regex_to_ignore(self):
labels = ["Source", "HaystackAI"]
evaluator = StatisticalEvaluator(
labels=labels, metric=StatisticalEvaluator.Metric.F1, regexes_to_ignore=[r"\d+", r"[^\w\s]"]
)
predictions = ["Open123! Source", "Haystack.AI"]
result = evaluator.run(predictions=predictions)
assert result["result"] == pytest.approx(5 / 6)
def test_run_with_multiple_ignore_parameters(self):
labels = ["Source", "HaystackAI"]
evaluator = StatisticalEvaluator(
labels=labels,
metric=StatisticalEvaluator.Metric.F1,
ignore_numbers=True,
ignore_punctuation=True,
ignore_case=True,
regexes_to_ignore=[r"[^\w\s\d]+"],
)
predictions = ["Open%123. !$Source", "Haystack.AI##"]
result = evaluator.run(predictions=predictions)
assert result["result"] == pytest.approx(5 / 6)
class TestStatisticalEvaluatorExactMatch:
def test_run_with_empty_inputs(self):
evaluator = StatisticalEvaluator(labels=[], metric=StatisticalEvaluator.Metric.EM)
result = evaluator.run(predictions=[])
evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
result = evaluator.run(predictions=[], labels=[])
assert len(result) == 1
assert result["result"] == 0.0
def test_run_with_different_lengths(self):
evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
labels = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
]
evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM)
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
with pytest.raises(ValueError):
evaluator.run(predictions)
evaluator.run(labels=labels, predictions=predictions)
def test_run_with_matching_predictions(self):
labels = ["OpenSource", "HaystackAI", "LLMs"]
evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM)
evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
predictions = ["OpenSource", "HaystackAI", "LLMs"]
result = evaluator.run(predictions=predictions)
result = evaluator.run(labels=labels, predictions=predictions)
assert len(result) == 1
assert result["result"] == 1.0
def test_run_with_single_prediction(self):
labels = ["OpenSource"]
evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM)
result = evaluator.run(predictions=["OpenSource"])
evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
result = evaluator.run(labels=["OpenSource"], predictions=["OpenSource"])
assert len(result) == 1
assert result["result"] == 1.0
def test_run_with_mismatched_predictions(self):
evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
labels = ["Source", "HaystackAI", "LLMs"]
evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM)
predictions = ["OpenSource", "HaystackAI", "LLMs"]
result = evaluator.run(predictions=predictions)
result = evaluator.run(labels=labels, predictions=predictions)
assert len(result) == 1
assert result["result"] == 2 / 3
def test_run_with_ignore_case(self):
labels = ["opensource", "HAYSTACKAI", "llMs"]
evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM, ignore_case=True)
predictions = ["OpenSource", "HaystackAI", "LLMs"]
result = evaluator.run(predictions=predictions)
assert len(result) == 1
assert result["result"] == 1.0
def test_run_with_ignore_punctuation(self):
labels = ["OpenSource", "HaystackAI", "LLMs"]
evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM, ignore_punctuation=True)
predictions = ["OpenSource!", "Haystack.AI", "LLMs,"]
result = evaluator.run(predictions=predictions)
assert result["result"] == 1.0
def test_run_with_ignore_numbers(self):
labels = ["OpenSource", "HaystackAI", "LLMs"]
evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM, ignore_numbers=True)
predictions = ["OpenSource123", "HaystackAI", "LLMs456"]
result = evaluator.run(predictions=predictions)
assert result["result"] == 1.0
def test_run_with_regex_to_ignore(self):
labels = ["OpenSource", "HaystackAI", "LLMs"]
evaluator = StatisticalEvaluator(
labels=labels, metric=StatisticalEvaluator.Metric.EM, regexes_to_ignore=[r"\d+"]
)
predictions = ["Open123Source", "HaystackAI", "LLMs456"]
result = evaluator.run(predictions=predictions)
assert result["result"] == 1.0
def test_run_with_multiple_regex_to_ignore(self):
labels = ["OpenSource", "HaystackAI", "LLMs"]
evaluator = StatisticalEvaluator(
labels=labels, metric=StatisticalEvaluator.Metric.EM, regexes_to_ignore=[r"\d+", r"\W+"]
)
predictions = ["Open123!Source", "Haystack.AI", "LLMs456,"]
result = evaluator.run(predictions=predictions)
assert result["result"] == 1.0
def test_run_with_multiple_ignore_parameters(self):
labels = ["OpenSource", "HaystackAI", "LLMs"]
evaluator = StatisticalEvaluator(
labels=labels,
metric=StatisticalEvaluator.Metric.EM,
ignore_numbers=True,
ignore_punctuation=True,
ignore_case=True,
regexes_to_ignore=[r"[^\w\s\d]+"],
)
predictions = ["Open%123!$Source", "Haystack.AI##", "^^LLMs456,"]
result = evaluator.run(predictions=predictions)
assert result["result"] == 1.0
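For completeness, the slimmed-down serialization surface can be exercised the same way the tests above do. A hedged round-trip sketch; the private attribute name is the one the tests already rely on:

from haystack.components.eval import StatisticalEvaluator, StatisticalMetric

evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
data = evaluator.to_dict()

# Only the metric survives serialization now; labels are a run-time input.
assert data["init_parameters"] == {"metric": "f1"}

restored = StatisticalEvaluator.from_dict(data)
assert restored._metric is StatisticalMetric.F1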