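# Unit tests for the StatisticalEvaluator component and its F1 and exact-match metrics.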
import pytest
from haystack.components.evaluators import StatisticalEvaluator, StatisticalMetric


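# Construction, string-to-enum coercion, and (de)serialization behaviour.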
class TestStatisticalEvaluator:
    def test_init_default(self):
        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
        assert evaluator._metric == StatisticalMetric.F1

    def test_init_with_string(self):
        evaluator = StatisticalEvaluator(metric="exact_match")
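        # The string "exact_match" should be coerced to the corresponding enum member.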
        assert evaluator._metric == StatisticalMetric.EM

    def test_to_dict(self):
        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
        expected_dict = {
            "type": "haystack.components.evaluators.statistical_evaluator.StatisticalEvaluator",
            "init_parameters": {"metric": "f1"},
        }
        assert evaluator.to_dict() == expected_dict

    def test_from_dict(self):
        evaluator = StatisticalEvaluator.from_dict(
            {
                "type": "haystack.components.evaluators.statistical_evaluator.StatisticalEvaluator",
                "init_parameters": {"metric": "f1"},
            }
        )
        assert evaluator._metric == StatisticalMetric.F1


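# The expected values below imply token-overlap F1, computed per label/prediction
# pair and averaged over all pairs.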
class TestStatisticalEvaluatorF1:
    def test_run_with_empty_inputs(self):
        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
        result = evaluator.run(labels=[], predictions=[])
        assert len(result) == 1
        assert result["result"] == 0.0

    def test_run_with_different_lengths(self):
        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
        labels = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
        ]
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        with pytest.raises(ValueError):
            evaluator.run(labels=labels, predictions=predictions)

    def test_run_with_matching_predictions(self):
        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
        labels = ["OpenSource", "HaystackAI", "LLMs"]
        predictions = ["OpenSource", "HaystackAI", "LLMs"]
        result = evaluator.run(labels=labels, predictions=predictions)
        assert len(result) == 1
        assert result["result"] == 1.0

    def test_run_with_single_prediction(self):
        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
        result = evaluator.run(labels=["Source"], predictions=["Open Source"])
        assert len(result) == 1
        assert result["result"] == pytest.approx(2 / 3)

    def test_run_with_mismatched_predictions(self):
        evaluator = StatisticalEvaluator(metric=StatisticalMetric.F1)
        labels = ["Source", "HaystackAI"]
        predictions = ["Open Source", "HaystackAI"]
        result = evaluator.run(labels=labels, predictions=predictions)
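        # Per-pair F1 scores are 2/3 (as in the test above) and 1.0; their average is 5/6.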
        assert len(result) == 1
        assert result["result"] == pytest.approx(5 / 6)


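# Exact match scores the fraction of predictions that are string-identical to their
# labels, as the 2/3 expectation in the mismatched case below implies.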
class TestStatisticalEvaluatorExactMatch:
    def test_run_with_empty_inputs(self):
        evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
        result = evaluator.run(predictions=[], labels=[])
        assert len(result) == 1
        assert result["result"] == 0.0

    def test_run_with_different_lengths(self):
        evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
        labels = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
        ]
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        with pytest.raises(ValueError):
            evaluator.run(labels=labels, predictions=predictions)

    def test_run_with_matching_predictions(self):
        evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
        labels = ["OpenSource", "HaystackAI", "LLMs"]
        predictions = ["OpenSource", "HaystackAI", "LLMs"]
        result = evaluator.run(labels=labels, predictions=predictions)
        assert len(result) == 1
        assert result["result"] == 1.0

    def test_run_with_single_prediction(self):
        evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
        result = evaluator.run(labels=["OpenSource"], predictions=["OpenSource"])
        assert len(result) == 1
        assert result["result"] == 1.0

    def test_run_with_mismatched_predictions(self):
        evaluator = StatisticalEvaluator(metric=StatisticalMetric.EM)
        labels = ["Source", "HaystackAI", "LLMs"]
        predictions = ["OpenSource", "HaystackAI", "LLMs"]
        result = evaluator.run(labels=labels, predictions=predictions)
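        # Only two of the three predictions match their labels exactly, giving 2/3.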
        assert len(result) == 1
        assert result["result"] == 2 / 3