refactor: Refactor StatisticalEvaluator (#6999)

* Refactor StatisticalEvaluator

* Update StatisticalEvaluator

* Rename StatisticalMetric.from_string to from_str and change internal logic

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* Fix tests

---------

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
Silvano Cerza 2024-02-15 16:47:35 +01:00 committed by GitHub
parent c82f787b41
commit 2b8a606cb8
3 changed files with 71 additions and 226 deletions
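
For orientation, a minimal usage sketch of the refactored component, based on the new signatures in the diffs below (the labels and predictions are made-up example data):

    from haystack.components.eval import StatisticalEvaluator, StatisticalMetric

    # The metric is now the only constructor argument; labels are passed to run().
    evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
    # Plain strings are also accepted and resolved via StatisticalMetric.from_str.
    em_evaluator = StatisticalEvaluator(metric="exact_match")

    labels = ["OpenSource", "HaystackAI"]        # expected answers (example data)
    predictions = ["Open Source", "HaystackAI"]  # model outputs (example data)

    result = evaluator.run(labels=labels, predictions=predictions)
    print(result["result"])  # word-overlap F1 score between 0.0 and 1.0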

haystack/components/eval/__init__.py

@@ -1,4 +1,4 @@
 from .sas_evaluator import SASEvaluator
-from .statistical_evaluator import StatisticalEvaluator
+from .statistical_evaluator import StatisticalEvaluator, StatisticalMetric
 
-__all__ = ["SASEvaluator", "StatisticalEvaluator"]
+__all__ = ["SASEvaluator", "StatisticalEvaluator", "StatisticalMetric"]

haystack/components/eval/statistical_evaluator.py

@@ -1,6 +1,6 @@
 import collections
 from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Union
 
 from numpy import array as np_array
 from numpy import mean as np_mean
@@ -8,7 +8,22 @@ from numpy import mean as np_mean
 from haystack import default_from_dict, default_to_dict
 from haystack.core.component import component
-from .preprocess import _preprocess_text
+
+
+class StatisticalMetric(Enum):
+    """
+    Metrics supported by the StatisticalEvaluator.
+    """
+
+    F1 = "f1"
+    EM = "exact_match"
+
+    @classmethod
+    def from_str(cls, metric: str) -> "StatisticalMetric":
+        map = {e.value: e for e in StatisticalMetric}
+        metric_ = map.get(metric)
+        if metric_ is None:
+            raise ValueError(f"Unknown statistical metric '{metric}'")
+        return metric_
 
 
 @component
@@ -22,79 +37,44 @@ class StatisticalEvaluator:
     - Exact Match: Measures the proportion of cases where prediction is identical to the expected label.
     """
 
-    class Metric(Enum):
-        """
-        Supported metrics
-        """
-
-        F1 = "F1"
-        EM = "Exact Match"
-
-    def __init__(
-        self,
-        labels: List[str],
-        metric: Metric,
-        regexes_to_ignore: Optional[List[str]] = None,
-        ignore_case: bool = False,
-        ignore_punctuation: bool = False,
-        ignore_numbers: bool = False,
-    ):
+    def __init__(self, metric: Union[str, StatisticalMetric]):
         """
         Creates a new instance of StatisticalEvaluator.
 
-        :param labels: The list of expected answers.
         :param metric: Metric to use for evaluation in this component. Supported metrics are F1 and Exact Match.
-        :param regexes_to_ignore: A list of regular expressions. If provided, it removes substrings
-            matching these regular expressions from both predictions and labels before comparison. Defaults to None.
-        :param ignore_case: If True, performs case-insensitive comparison. Defaults to False.
-        :param ignore_punctuation: If True, removes punctuation from both predictions and labels before
-            comparison. Defaults to False.
-        :param ignore_numbers: If True, removes numerical digits from both predictions and labels
-            before comparison. Defaults to False.
         """
-        self._labels = labels
+        if isinstance(metric, str):
+            metric = StatisticalMetric.from_str(metric)
         self._metric = metric
-        self._regexes_to_ignore = regexes_to_ignore
-        self._ignore_case = ignore_case
-        self._ignore_punctuation = ignore_punctuation
-        self._ignore_numbers = ignore_numbers
 
-        self._metric_function = {
-            StatisticalEvaluator.Metric.F1: self._f1,
-            StatisticalEvaluator.Metric.EM: self._exact_match,
-        }[self._metric]
+        self._metric_function = {StatisticalMetric.F1: self._f1, StatisticalMetric.EM: self._exact_match}[self._metric]
 
     def to_dict(self) -> Dict[str, Any]:
-        return default_to_dict(
-            self,
-            labels=self._labels,
-            metric=self._metric.value,
-            regexes_to_ignore=self._regexes_to_ignore,
-            ignore_case=self._ignore_case,
-            ignore_punctuation=self._ignore_punctuation,
-            ignore_numbers=self._ignore_numbers,
-        )
+        return default_to_dict(self, metric=self._metric.value)
 
     @classmethod
     def from_dict(cls, data: Dict[str, Any]) -> "StatisticalEvaluator":
-        data["init_parameters"]["metric"] = StatisticalEvaluator.Metric(data["init_parameters"]["metric"])
+        data["init_parameters"]["metric"] = StatisticalMetric(data["init_parameters"]["metric"])
         return default_from_dict(cls, data)
 
     @component.output_types(result=float)
-    def run(self, predictions: List[str]) -> Dict[str, Any]:
-        if len(predictions) != len(self._labels):
-            raise ValueError("The number of predictions and labels must be the same.")
-
-        predictions = _preprocess_text(
-            predictions, self._regexes_to_ignore, self._ignore_case, self._ignore_punctuation, self._ignore_numbers
-        )
-        labels = _preprocess_text(
-            self._labels, self._regexes_to_ignore, self._ignore_case, self._ignore_punctuation, self._ignore_numbers
-        )
+    def run(self, labels: List[str], predictions: List[str]) -> Dict[str, Any]:
+        """
+        Run the StatisticalEvaluator to compute the metric between a list of predictions and a list of labels.
+        Both must be lists of strings of the same length.
+
+        :param predictions: List of predictions.
+        :param labels: List of labels against which the predictions are compared.
+        :returns: A dictionary with the following outputs:
+            * `result` - Calculated result of the chosen metric.
+        """
+        if len(labels) != len(predictions):
+            raise ValueError("The number of predictions and labels must be the same.")
 
         return {"result": self._metric_function(labels, predictions)}
 
-    def _f1(self, labels: List[str], predictions: List[str]):
+    @staticmethod
+    def _f1(labels: List[str], predictions: List[str]):
         """
         Measure word overlap between predictions and labels.
         """
@@ -120,7 +100,8 @@ class StatisticalEvaluator:
 
         return np_mean(scores)
 
-    def _exact_match(self, labels: List[str], predictions: List[str]) -> float:
+    @staticmethod
+    def _exact_match(labels: List[str], predictions: List[str]) -> float:
        """
        Measure the proportion of cases where prediction is identical to the expected label.
        """

test_statistical_evaluator.py

@@ -1,33 +1,23 @@
 import pytest
 
-from haystack.components.eval import StatisticalEvaluator
+from haystack.components.eval import StatisticalEvaluator, StatisticalMetric
 
 
 class TestStatisticalEvaluator:
     def test_init_default(self):
-        labels = ["label1", "label2", "label3"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
-        assert evaluator._labels == labels
-        assert evaluator._metric == StatisticalEvaluator.Metric.F1
-        assert evaluator._regexes_to_ignore is None
-        assert evaluator._ignore_case is False
-        assert evaluator._ignore_punctuation is False
-        assert evaluator._ignore_numbers is False
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
+        assert evaluator._metric == StatisticalMetric.F1
+
+    def test_init_with_string(self):
+        evaluator = StatisticalEvaluator(metric="exact_match")
+        assert evaluator._metric == StatisticalMetric.EM
 
     def test_to_dict(self):
-        labels = ["label1", "label2", "label3"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
 
         expected_dict = {
             "type": "haystack.components.eval.statistical_evaluator.StatisticalEvaluator",
-            "init_parameters": {
-                "labels": labels,
-                "metric": "F1",
-                "regexes_to_ignore": None,
-                "ignore_case": False,
-                "ignore_punctuation": False,
-                "ignore_numbers": False,
-            },
+            "init_parameters": {"metric": "f1"},
         }
         assert evaluator.to_dict() == expected_dict
@@ -35,225 +25,99 @@ class TestStatisticalEvaluator:
         evaluator = StatisticalEvaluator.from_dict(
             {
                 "type": "haystack.components.eval.statistical_evaluator.StatisticalEvaluator",
-                "init_parameters": {
-                    "labels": ["label1", "label2", "label3"],
-                    "metric": "F1",
-                    "regexes_to_ignore": None,
-                    "ignore_case": False,
-                    "ignore_punctuation": False,
-                    "ignore_numbers": False,
-                },
+                "init_parameters": {"metric": "f1"},
             }
         )
-        assert evaluator._labels == ["label1", "label2", "label3"]
-        assert evaluator._metric == StatisticalEvaluator.Metric.F1
-        assert evaluator._regexes_to_ignore is None
-        assert evaluator._ignore_case is False
-        assert evaluator._ignore_punctuation is False
-        assert evaluator._ignore_numbers is False
+        assert evaluator._metric == StatisticalMetric.F1
 
 
 class TestStatisticalEvaluatorF1:
     def test_run_with_empty_inputs(self):
-        evaluator = StatisticalEvaluator(labels=[], metric=StatisticalEvaluator.Metric.F1)
-        result = evaluator.run(predictions=[])
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
+        result = evaluator.run(labels=[], predictions=[])
         assert len(result) == 1
         assert result["result"] == 0.0
 
     def test_run_with_different_lengths(self):
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
         labels = [
             "A construction budget of US $2.3 billion",
             "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
         ]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
         predictions = [
             "A construction budget of US $2.3 billion",
             "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
         with pytest.raises(ValueError):
-            evaluator.run(predictions)
+            evaluator.run(labels=labels, predictions=predictions)
 
     def test_run_with_matching_predictions(self):
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
         labels = ["OpenSource", "HaystackAI", "LLMs"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
         predictions = ["OpenSource", "HaystackAI", "LLMs"]
-        result = evaluator.run(predictions=predictions)
+        result = evaluator.run(labels=labels, predictions=predictions)
         assert len(result) == 1
         assert result["result"] == 1.0
 
     def test_run_with_single_prediction(self):
-        labels = ["Source"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
-        result = evaluator.run(predictions=["Open Source"])
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
+        result = evaluator.run(labels=["Source"], predictions=["Open Source"])
         assert len(result) == 1
         assert result["result"] == pytest.approx(2 / 3)
 
     def test_run_with_mismatched_predictions(self):
         labels = ["Source", "HaystackAI"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1)
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
         predictions = ["Open Source", "HaystackAI"]
-        result = evaluator.run(predictions=predictions)
+        result = evaluator.run(labels=labels, predictions=predictions)
         assert len(result) == 1
         assert result["result"] == pytest.approx(5 / 6)
 
-    def test_run_with_ignore_case(self):
-        labels = ["source", "HAYSTACKAI"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1, ignore_case=True)
-        predictions = ["Open Source", "HaystackAI"]
-        result = evaluator.run(predictions=predictions)
-        assert len(result) == 1
-        assert result["result"] == pytest.approx(5 / 6)
-
-    def test_run_with_ignore_punctuation(self):
-        labels = ["Source", "HaystackAI"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1, ignore_punctuation=True)
-        predictions = ["Open Source!", "Haystack.AI"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == pytest.approx(5 / 6)
-
-    def test_run_with_ignore_numbers(self):
-        labels = ["Source", "HaystackAI"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.F1, ignore_numbers=True)
-        predictions = ["Open Source123", "HaystackAI"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == pytest.approx(5 / 6)
-
-    def test_run_with_regex_to_ignore(self):
-        labels = ["Source", "HaystackAI"]
-        evaluator = StatisticalEvaluator(
-            labels=labels, metric=StatisticalEvaluator.Metric.F1, regexes_to_ignore=[r"\d+"]
-        )
-        predictions = ["Open123 Source", "HaystackAI"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == pytest.approx(5 / 6)
-
-    def test_run_with_multiple_regex_to_ignore(self):
-        labels = ["Source", "HaystackAI"]
-        evaluator = StatisticalEvaluator(
-            labels=labels, metric=StatisticalEvaluator.Metric.F1, regexes_to_ignore=[r"\d+", r"[^\w\s]"]
-        )
-        predictions = ["Open123! Source", "Haystack.AI"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == pytest.approx(5 / 6)
-
-    def test_run_with_multiple_ignore_parameters(self):
-        labels = ["Source", "HaystackAI"]
-        evaluator = StatisticalEvaluator(
-            labels=labels,
-            metric=StatisticalEvaluator.Metric.F1,
-            ignore_numbers=True,
-            ignore_punctuation=True,
-            ignore_case=True,
-            regexes_to_ignore=[r"[^\w\s\d]+"],
-        )
-        predictions = ["Open%123. !$Source", "Haystack.AI##"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == pytest.approx(5 / 6)
-
 
 class TestStatisticalEvaluatorExactMatch:
     def test_run_with_empty_inputs(self):
-        evaluator = StatisticalEvaluator(labels=[], metric=StatisticalEvaluator.Metric.EM)
-        result = evaluator.run(predictions=[])
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
+        result = evaluator.run(predictions=[], labels=[])
        assert len(result) == 1
        assert result["result"] == 0.0
 
     def test_run_with_different_lengths(self):
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
         labels = [
             "A construction budget of US $2.3 billion",
             "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
         ]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM)
         predictions = [
             "A construction budget of US $2.3 billion",
             "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
             "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
         ]
         with pytest.raises(ValueError):
-            evaluator.run(predictions)
+            evaluator.run(labels=labels, predictions=predictions)
 
     def test_run_with_matching_predictions(self):
         labels = ["OpenSource", "HaystackAI", "LLMs"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM)
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
         predictions = ["OpenSource", "HaystackAI", "LLMs"]
-        result = evaluator.run(predictions=predictions)
+        result = evaluator.run(labels=labels, predictions=predictions)
         assert len(result) == 1
         assert result["result"] == 1.0
 
     def test_run_with_single_prediction(self):
-        labels = ["OpenSource"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM)
-        result = evaluator.run(predictions=["OpenSource"])
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
+        result = evaluator.run(labels=["OpenSource"], predictions=["OpenSource"])
         assert len(result) == 1
         assert result["result"] == 1.0
 
     def test_run_with_mismatched_predictions(self):
+        evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
         labels = ["Source", "HaystackAI", "LLMs"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM)
         predictions = ["OpenSource", "HaystackAI", "LLMs"]
-        result = evaluator.run(predictions=predictions)
+        result = evaluator.run(labels=labels, predictions=predictions)
         assert len(result) == 1
         assert result["result"] == 2 / 3
 
-    def test_run_with_ignore_case(self):
-        labels = ["opensource", "HAYSTACKAI", "llMs"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM, ignore_case=True)
-        predictions = ["OpenSource", "HaystackAI", "LLMs"]
-        result = evaluator.run(predictions=predictions)
-        assert len(result) == 1
-        assert result["result"] == 1.0
-
-    def test_run_with_ignore_punctuation(self):
-        labels = ["OpenSource", "HaystackAI", "LLMs"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM, ignore_punctuation=True)
-        predictions = ["OpenSource!", "Haystack.AI", "LLMs,"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == 1.0
-
-    def test_run_with_ignore_numbers(self):
-        labels = ["OpenSource", "HaystackAI", "LLMs"]
-        evaluator = StatisticalEvaluator(labels=labels, metric=StatisticalEvaluator.Metric.EM, ignore_numbers=True)
-        predictions = ["OpenSource123", "HaystackAI", "LLMs456"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == 1.0
-
-    def test_run_with_regex_to_ignore(self):
-        labels = ["OpenSource", "HaystackAI", "LLMs"]
-        evaluator = StatisticalEvaluator(
-            labels=labels, metric=StatisticalEvaluator.Metric.EM, regexes_to_ignore=[r"\d+"]
-        )
-        predictions = ["Open123Source", "HaystackAI", "LLMs456"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == 1.0
-
-    def test_run_with_multiple_regex_to_ignore(self):
-        labels = ["OpenSource", "HaystackAI", "LLMs"]
-        evaluator = StatisticalEvaluator(
-            labels=labels, metric=StatisticalEvaluator.Metric.EM, regexes_to_ignore=[r"\d+", r"\W+"]
-        )
-        predictions = ["Open123!Source", "Haystack.AI", "LLMs456,"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == 1.0
-
-    def test_run_with_multiple_ignore_parameters(self):
-        labels = ["OpenSource", "HaystackAI", "LLMs"]
-        evaluator = StatisticalEvaluator(
-            labels=labels,
-            metric=StatisticalEvaluator.Metric.EM,
-            ignore_numbers=True,
-            ignore_punctuation=True,
-            ignore_case=True,
-            regexes_to_ignore=[r"[^\w\s\d]+"],
-        )
-        predictions = ["Open%123!$Source", "Haystack.AI##", "^^LLMs456,"]
-        result = evaluator.run(predictions=predictions)
-        assert result["result"] == 1.0