mirror of https://github.com/deepset-ai/haystack.git
synced 2025-10-19 11:58:44 +00:00

feat: Add Exact Match metric (#6696)

* Add exact match metric
* Add release notes
* Cleanup comments in test_eval_exact_match.py
* Create separate preprocessing function; Add output_key parameter
* Update release note

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Co-authored-by: Julian Risch <julian.risch@deepset.ai>

This commit is contained in:
parent 8a08ab52e1
commit a238c6dd51
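In short, `calculate_metrics` now accepts `Metric.EM` together with an `output_key` and optional normalization flags. A minimal usage sketch, assuming an `eval_result` already produced by `haystack.evaluation.eval.eval(...)` on a pipeline with an `answer_builder` component (as in the e2e test changes below):

```python
from haystack.evaluation.metrics import Metric

# `eval_result` is assumed to come from haystack.evaluation.eval.eval(...)
metrics_default = eval_result.calculate_metrics(Metric.EM, output_key="answers")
metrics_relaxed = eval_result.calculate_metrics(
    Metric.EM, output_key="answers", ignore_case=True, ignore_punctuation=True, ignore_numbers=True
)
print(metrics_default["exact_match"], metrics_relaxed["exact_match"])
```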
@@ -35,11 +35,7 @@ def test_extractive_qa_pipeline(tmp_path):
                 query="Who lives in Paris?",
                 score=0.7713339924812317,
                 data="Jean and I",
-                document=Document(
-                    id="6c90b78ad94e4e634e2a067b5fe2d26d4ce95405ec222cbaefaeb09ab4dce81e",
-                    content="My name is Jean and I live in Paris.",
-                    score=0.33144005810482535,
-                ),
+                document=Document(content="My name is Jean and I live in Paris.", score=0.33144005810482535),
                 context=None,
                 document_offset=ExtractedAnswer.Span(start=11, end=21),
                 context_offset=None,
@@ -65,11 +61,7 @@ def test_extractive_qa_pipeline(tmp_path):
                 query="Who lives in Berlin?",
                 score=0.7047999501228333,
                 data="Mark and I",
-                document=Document(
-                    id="10a183e965c2e107e20507c717f16559c58a8ba4bc7c577ea8dc32a8d6ca7a20",
-                    content="My name is Mark and I live in Berlin.",
-                    score=0.33144005810482535,
-                ),
+                document=Document(content="My name is Mark and I live in Berlin.", score=0.33144005810482535),
                 context=None,
                 document_offset=ExtractedAnswer.Span(start=11, end=21),
                 context_offset=None,
@@ -95,11 +87,7 @@ def test_extractive_qa_pipeline(tmp_path):
                 query="Who lives in Rome?",
                 score=0.7661304473876953,
                 data="Giorgio and I",
-                document=Document(
-                    id="fb0f1efe94b3c78aa1c4e5a17a5ef8270f70e89d36a3665c8362675e8a769a27",
-                    content="My name is Giorgio and I live in Rome.",
-                    score=0.33144005810482535,
-                ),
+                document=Document(content="My name is Giorgio and I live in Rome.", score=0.33144005810482535),
                 context=None,
                 document_offset=ExtractedAnswer.Span(start=11, end=24),
                 context_offset=None,
@@ -127,10 +115,14 @@ def test_extractive_qa_pipeline(tmp_path):
     assert len(eval_result.outputs) == len(expected_outputs) == len(inputs)
     assert eval_result.runnable.to_dict() == qa_pipeline.to_dict()
 
-    metrics = eval_result.calculate_metrics(Metric.EM)
+    metrics_default = eval_result.calculate_metrics(Metric.EM, output_key="answers")
+    metrics_custom_parameters = eval_result.calculate_metrics(
+        Metric.EM, output_key="answers", ignore_case=True, ignore_punctuation=True, ignore_numbers=True
+    )
     # Save metric results to json
-    metrics.save(tmp_path / "exact_match_score.json")
+    metrics_default.save(tmp_path / "exact_match_score.json")
 
-    assert metrics["exact_match"] == 1.0
+    assert metrics_default["exact_match"] == 1.0
+    assert metrics_custom_parameters["exact_match"] == 1.0
     with open(tmp_path / "exact_match_score.json", "r") as f:
-        assert metrics == json.load(f)
+        assert metrics_default == json.load(f)
@@ -7,7 +7,7 @@ from haystack.components.embedders import SentenceTransformersDocumentEmbedder,
 from haystack.components.generators import HuggingFaceLocalGenerator
 from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
 from haystack.components.writers import DocumentWriter
-from haystack.dataclasses import Document
+from haystack.dataclasses import Document, GeneratedAnswer
 from haystack.document_stores.in_memory import InMemoryDocumentStore
 from haystack.evaluation.eval import eval
 from haystack.evaluation.metrics import Metric
@@ -59,9 +59,54 @@ def test_bm25_rag_pipeline(tmp_path):
     ]
 
     expected_outputs = [
-        {"llm": {"replies": ["Jean"]}},
-        {"llm": {"replies": ["Mark"]}},
-        {"llm": {"replies": ["Giorgio"]}},
+        {
+            "answer_builder": {
+                "answers": [
+                    GeneratedAnswer(
+                        data="Jean",
+                        query="Who lives in Paris?",
+                        documents=[
+                            Document(content="My name is Jean and I live in Paris.", score=0.33144005810482535),
+                            Document(content="My name is Giorgio and I live in Rome.", score=-0.17938556566116537),
+                            Document(content="My name is Mark and I live in Berlin.", score=-0.17938556566116537),
+                        ],
+                        meta={},
+                    )
+                ]
+            }
+        },
+        {
+            "answer_builder": {
+                "answers": [
+                    GeneratedAnswer(
+                        data="Mark",
+                        query="Who lives in Berlin?",
+                        documents=[
+                            Document(content="My name is Mark and I live in Berlin.", score=0.33144005810482535),
+                            Document(content="My name is Giorgio and I live in Rome.", score=-0.17938556566116537),
+                            Document(content="My name is Jean and I live in Paris.", score=-0.17938556566116537),
+                        ],
+                        meta={},
+                    )
+                ]
+            }
+        },
+        {
+            "answer_builder": {
+                "answers": [
+                    GeneratedAnswer(
+                        data="Giorgio",
+                        query="Who lives in Rome?",
+                        documents=[
+                            Document(content="My name is Giorgio and I live in Rome.", score=0.33144005810482535),
+                            Document(content="My name is Mark and I live in Berlin.", score=-0.17938556566116537),
+                            Document(content="My name is Jean and I live in Paris.", score=-0.17938556566116537),
+                        ],
+                        meta={},
+                    )
+                ]
+            }
+        },
     ]
 
     eval_result = eval(rag_pipeline, inputs=inputs, expected_outputs=expected_outputs)
@@ -71,13 +116,17 @@ def test_bm25_rag_pipeline(tmp_path):
     assert len(eval_result.outputs) == len(expected_outputs) == len(inputs)
     assert eval_result.runnable.to_dict() == rag_pipeline.to_dict()
 
-    metrics = eval_result.calculate_metrics(Metric.EM)
+    metrics_default = eval_result.calculate_metrics(Metric.EM, output_key="answers")
+    metrics_custom_parameters = eval_result.calculate_metrics(
+        Metric.EM, output_key="answers", ignore_case=True, ignore_punctuation=True, ignore_numbers=True
+    )
     # Save metric results to json
-    metrics.save(tmp_path / "exact_match_score.json")
+    metrics_default.save(tmp_path / "exact_match_score.json")
 
-    assert metrics["exact_match"] == 1.0
+    assert metrics_default["exact_match"] == 1.0
+    assert metrics_custom_parameters["exact_match"] == 1.0
     with open(tmp_path / "exact_match_score.json", "r") as f:
-        assert metrics == json.load(f)
+        assert metrics_default == json.load(f)
 
 
 def test_embedding_retrieval_rag_pipeline(tmp_path):
@@ -142,9 +191,54 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
     ]
 
     expected_outputs = [
-        {"llm": {"replies": ["Jean"]}},
-        {"llm": {"replies": ["Mark"]}},
-        {"llm": {"replies": ["Giorgio"]}},
+        {
+            "answer_builder": {
+                "answers": [
+                    GeneratedAnswer(
+                        data="Jean",
+                        query="Who lives in Paris?",
+                        documents=[
+                            Document(content="My name is Jean and I live in Paris.", score=0.33144005810482535),
+                            Document(content="My name is Giorgio and I live in Rome.", score=-0.17938556566116537),
+                            Document(content="My name is Mark and I live in Berlin.", score=-0.17938556566116537),
+                        ],
+                        meta={},
+                    )
+                ]
+            }
+        },
+        {
+            "answer_builder": {
+                "answers": [
+                    GeneratedAnswer(
+                        data="Mark",
+                        query="Who lives in Berlin?",
+                        documents=[
+                            Document(content="My name is Mark and I live in Berlin.", score=0.33144005810482535),
+                            Document(content="My name is Giorgio and I live in Rome.", score=-0.17938556566116537),
+                            Document(content="My name is Jean and I live in Paris.", score=-0.17938556566116537),
+                        ],
+                        meta={},
+                    )
+                ]
+            }
+        },
+        {
+            "answer_builder": {
+                "answers": [
+                    GeneratedAnswer(
+                        data="Giorgio",
+                        query="Who lives in Rome?",
+                        documents=[
+                            Document(content="My name is Giorgio and I live in Rome.", score=0.33144005810482535),
+                            Document(content="My name is Mark and I live in Berlin.", score=-0.17938556566116537),
+                            Document(content="My name is Jean and I live in Paris.", score=-0.17938556566116537),
+                        ],
+                        meta={},
+                    )
+                ]
+            }
+        },
     ]
 
     eval_result = eval(rag_pipeline, inputs=inputs, expected_outputs=expected_outputs)
@@ -154,10 +248,14 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
     assert len(eval_result.outputs) == len(expected_outputs) == len(inputs)
     assert eval_result.runnable.to_dict() == rag_pipeline.to_dict()
 
-    metrics = eval_result.calculate_metrics(Metric.EM)
+    metrics_default = eval_result.calculate_metrics(Metric.EM, output_key="answers")
+    metrics_custom_parameters = eval_result.calculate_metrics(
+        Metric.EM, output_key="answers", ignore_case=True, ignore_punctuation=True, ignore_numbers=True
+    )
     # Save metric results to json
-    metrics.save(tmp_path / "exact_match_score.json")
+    metrics_default.save(tmp_path / "exact_match_score.json")
 
-    assert metrics["exact_match"] == 1.0
+    assert metrics_default["exact_match"] == 1.0
+    assert metrics_custom_parameters["exact_match"] == 1.0
     with open(tmp_path / "exact_match_score.json", "r") as f:
-        assert metrics == json.load(f)
+        assert metrics_default == json.load(f)
@@ -1,7 +1,10 @@
 from typing import Any, Callable, Dict, List, Union
 
+import numpy as np
+
 from haystack import Pipeline
 from haystack.core.component import Component
+from haystack.evaluation.eval_utils import get_answers_from_output, preprocess_text
 from haystack.evaluation.metrics import Metric, MetricsResult
 
 
@@ -29,9 +32,15 @@ class EvaluationResult:
         self.outputs = outputs
         self.expected_outputs = expected_outputs
 
+        # Determine the type of the runnable
+        if str(type(runnable).__name__) == "Pipeline":
+            self.runnable_type = "pipeline"
+        else:
+            self.runnable_type = "component"
+
         # Mapping of metrics to their corresponding functions.
         # This should be kept in sync with the Metric enum
-        self._supported_metrics = {
+        self._supported_metrics: Dict[Metric, Callable[..., MetricsResult]] = {
             Metric.RECALL: self._calculate_recall,
             Metric.MRR: self._calculate_mrr,
             Metric.MAP: self._calculate_map,
@@ -65,8 +74,45 @@ class EvaluationResult:
     def _calculate_f1(self):
         return MetricsResult({"f1": None})
 
-    def _calculate_em(self):
-        return MetricsResult({"exact_match": 1.0})
+    def _calculate_em(
+        self, output_key: str, regexes_to_ignore=None, ignore_case=False, ignore_punctuation=False, ignore_numbers=False
+    ) -> MetricsResult:
+        """
+        Calculates the Exact Match (EM) score between two lists of predictions and labels.
+        Exact Match (EM) score measures the percentage of samples where the predicted text exactly matches the
+        corresponding ground truth label.
+
+        :param output_key: The key of the output to use for comparison.
+        :param regexes_to_ignore (list, optional): A list of regular expressions. If provided, it removes substrings
+            matching these regular expressions from both predictions and labels before comparison. Defaults to None.
+        :param ignore_case (bool, optional): If True, performs case-insensitive comparison. Defaults to False.
+        :param ignore_punctuation (bool, optional): If True, removes punctuation from both predictions and labels before
+            comparison. Defaults to False.
+        :param ignore_numbers (bool, optional): If True, removes numerical digits from both predictions and labels
+            before comparison. Defaults to False.
+        :return: A MetricsResult object containing the calculated Exact Match (EM) score.
+        """
+
+        predictions = get_answers_from_output(
+            outputs=self.outputs, output_key=output_key, runnable_type=self.runnable_type
+        )
+        labels = get_answers_from_output(
+            outputs=self.expected_outputs, output_key=output_key, runnable_type=self.runnable_type
+        )
+
+        if len(predictions) != len(labels):
+            raise ValueError("The number of predictions and labels must be the same.")
+        if len(predictions) == len(labels) == 0:
+            # Return Exact Match as 0 for no inputs
+            return MetricsResult({"exact_match": 0.0})
+
+        predictions = preprocess_text(predictions, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers)
+        labels = preprocess_text(labels, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers)
+
+        score_list = np.array(predictions) == np.array(labels)
+        exact_match_score = np.mean(score_list)
+
+        return MetricsResult({"exact_match": exact_match_score})
 
     def _calculate_sas(self):
         return MetricsResult({"exact_match": None})
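The scoring itself reduces to an element-wise comparison of the preprocessed strings. A tiny standalone sketch of that step, reusing the values from `test_exact_match_negative_case` added further down:

```python
import numpy as np

# Element-wise comparison as done in _calculate_em: 2 of 3 answers match exactly.
predictions = ["OpenSource", "HaystackAI", "LLMs"]
labels = ["Source", "HaystackAI", "LLMs"]

score_list = np.array(predictions) == np.array(labels)  # [False, True, True]
print(np.mean(score_list))  # 0.666... -> exact_match = 2/3
```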
haystack/evaluation/eval_utils.py (new file, 65 lines):

import re
import string
from typing import Any, Dict, List


def preprocess_text(
    texts: List[str], regexes_to_ignore=None, ignore_case=False, ignore_punctuation=False, ignore_numbers=False
) -> List[str]:
    """
    Preprocess the outputs of the runnable to remove unwanted characters.

    :param regexes_to_ignore (list, optional): A list of regular expressions. If provided, it removes substrings
        matching these regular expressions from the text. Defaults to None.
    :param ignore_case (bool, optional): If True, converts all characters to lowercase. Defaults to False.
    :param ignore_punctuation (bool, optional): If True, removes punctuation from the text. Defaults to False.
    :param ignore_numbers (bool, optional): If True, removes numerical digits from the text. Defaults to False.
    :return: A list of preprocessed strings.
    """
    if regexes_to_ignore:
        combined_regex = "|".join(regexes_to_ignore)
        texts = [re.sub(combined_regex, "", text, flags=re.IGNORECASE) for text in texts]

    if ignore_case:
        texts = [text.lower() for text in texts]

    if ignore_punctuation:
        translator = str.maketrans("", "", string.punctuation)
        texts = [text.translate(translator) for text in texts]

    if ignore_numbers:
        translator = str.maketrans("", "", string.digits)
        texts = [text.translate(translator) for text in texts]

    return texts


def get_answers_from_output(outputs: List[Dict[str, Any]], output_key: str, runnable_type: str) -> List[str]:
    """
    Extracts the answers from the output of a pipeline or component.

    :param outputs: The outputs of the runnable.
    :return: List of answers from the runnable output.
    """
    answers = []
    if runnable_type == "pipeline":
        # Iterate over output from each Pipeline run
        for output in outputs:
            # Iterate over output of component in each Pipeline run
            for component_output in output.values():
                # Only extract answers based on key
                for key in component_output.keys():
                    if output_key in key:
                        for generated_answer in component_output[output_key]:
                            if generated_answer.data:
                                answers.append(generated_answer.data)
    else:
        # Iterate over output from each Component run
        for output in outputs:
            # Only extract answers based on key
            for key in output.keys():
                if output_key in key:
                    for generated_answer in output[output_key]:
                        if generated_answer.data:
                            answers.append(generated_answer.data)
    return answers
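For reference, the expected effect of each `preprocess_text` option; the inputs and outputs below mirror the new unit tests in `test/evaluation/test_eval_utils.py`:

```python
from haystack.evaluation.eval_utils import preprocess_text

texts = ["Test, Output-1!"]
print(preprocess_text(texts, ignore_punctuation=True))                # ['Test Output1']
print(preprocess_text(texts, regexes_to_ignore=[r"\d", r"[^\w\s]"]))  # ['Test Output']
print(preprocess_text(texts, ignore_case=True, ignore_punctuation=True, ignore_numbers=True))  # ['test output']
```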
releasenotes/notes/add-exact-match-a7df21717238b771.yaml (new file, 8 lines):

---
features:
  - |
    Adds support for the Exact Match metric to `EvaluationResult.calculate_metrics(...)`:
    ```python
    from haystack.evaluation.metrics import Metric
    exact_match_metric = eval_result.calculate_metrics(Metric.EM, output_key="answers")
    ```
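Building on the release-note snippet, here is a self-contained sketch of the relaxed-matching flags; the `answer_builder` wrapping and the sample strings are illustrative and follow the new unit-test fixtures:

```python
from haystack import Pipeline
from haystack.dataclasses import GeneratedAnswer
from haystack.evaluation.eval import EvaluationResult
from haystack.evaluation.metrics import Metric

predictions = ["OpenSource!", "HaystackAI", "LLMs,"]
labels = ["opensource", "HaystackAI", "LLMs"]

# Wrap predictions/labels the way a RAG pipeline's answer_builder emits them.
outputs = [
    {"answer_builder": {"answers": [GeneratedAnswer(data=p, query="", documents=[], meta={})]}} for p in predictions
]
expected_outputs = [
    {"answer_builder": {"answers": [GeneratedAnswer(data=l, query="", documents=[], meta={})]}} for l in labels
]

eval_result = EvaluationResult(Pipeline(), [], outputs, expected_outputs)

strict = eval_result.calculate_metrics(Metric.EM, output_key="answers")
relaxed = eval_result.calculate_metrics(Metric.EM, output_key="answers", ignore_case=True, ignore_punctuation=True)

print(strict["exact_match"])   # only 1 of 3 answers matches verbatim
print(relaxed["exact_match"])  # 1.0 once case and punctuation are ignored
```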
test/evaluation/__init__.py (new file, empty)

test/evaluation/test_eval_exact_match.py (new file, 167 lines):

from haystack import Pipeline
from haystack.dataclasses import GeneratedAnswer
from haystack.evaluation.eval import EvaluationResult


class TestExactMatch:
    def create_evaluation_result(self, predictions, labels):
        """
        Creates an evaluation result of a RAG pipeline using the list of predictions and labels for testing the exact match.
        """
        runnable = Pipeline()
        inputs = []
        outputs = [
            {"answer_builder": {"answers": [GeneratedAnswer(data=pred, query="", documents=[], meta={})]}}
            for pred in predictions
        ]
        expected_outputs = [
            {"answer_builder": {"answers": [GeneratedAnswer(data=label, query="", documents=[], meta={})]}}
            for label in labels
        ]
        evaluation_result = EvaluationResult(runnable, inputs, outputs, expected_outputs)
        return evaluation_result

    def test_exact_match_empty_inputs(self):
        """
        Test exact match with empty inputs
        """
        runnable = Pipeline()
        inputs = []
        outputs = [
            {"answer_builder": {"answers": []}},
            {"answer_builder": {"answers": []}},
            {"answer_builder": {"answers": []}},
        ]
        expected_outputs = [
            {"answer_builder": {"answers": []}},
            {"answer_builder": {"answers": []}},
            {"answer_builder": {"answers": []}},
        ]
        evaluation_result = EvaluationResult(runnable, inputs, outputs, expected_outputs)
        # Expecting 0% exact match for empty inputs
        em_result = evaluation_result._calculate_em(output_key="answers")

        assert em_result["exact_match"] == 0.0

    def test_exact_match_same_inputs(self):
        """
        Test exact match with default parameters
        """
        predictions = ["OpenSource", "HaystackAI", "LLMs"]
        labels = ["OpenSource", "HaystackAI", "LLMs"]
        evaluation_result = self.create_evaluation_result(predictions, labels)
        em_result = evaluation_result._calculate_em(output_key="answers")

        assert em_result["exact_match"] == 1.0

    def test_exact_match_single_word(self):
        """
        Test exact match with single-word inputs
        """
        predictions = ["OpenSource"]
        labels = ["OpenSource"]

        evaluation_result = self.create_evaluation_result(predictions, labels)
        em_result = evaluation_result._calculate_em(output_key="answers")

        assert em_result["exact_match"] == 1.0

    def test_exact_match_negative_case(self):
        """
        Test exact match with deliberately mismatched predictions and labels
        """
        predictions = ["OpenSource", "HaystackAI", "LLMs"]
        labels = ["Source", "HaystackAI", "LLMs"]

        evaluation_result = self.create_evaluation_result(predictions, labels)
        # Expecting EM to be 2/3 as 2 out of 3 items match
        expected_em = 2 / 3
        em_result = evaluation_result._calculate_em(output_key="answers")

        assert em_result["exact_match"] == expected_em

    def test_exact_match_ignore_case(self):
        """
        Test exact match with ignoring case sensitivity
        """
        predictions = ["OpenSource", "HaystackAI", "LLMs"]
        labels = ["opensource", "HAYSTACKAI", "llMs"]

        evaluation_result = self.create_evaluation_result(predictions, labels)
        # Exact match after case ignoring
        em_result = evaluation_result._calculate_em(output_key="answers", ignore_case=True)

        assert em_result["exact_match"] == 1.0

    def test_exact_match_ignore_punctuation(self):
        """
        Test exact match with ignoring punctuation
        """
        predictions = ["OpenSource!", "Haystack.AI", "LLMs,"]
        labels = ["OpenSource", "HaystackAI", "LLMs"]

        evaluation_result = self.create_evaluation_result(predictions, labels)
        # Exact match after ignoring punctuation
        em_result = evaluation_result._calculate_em(output_key="answers", ignore_punctuation=True)

        assert em_result["exact_match"] == 1.0

    def test_exact_match_ignore_numbers(self):
        """
        Test exact match with ignoring numbers
        """
        predictions = ["OpenSource123", "HaystackAI", "LLMs456"]
        labels = ["OpenSource", "HaystackAI", "LLMs"]

        evaluation_result = self.create_evaluation_result(predictions, labels)
        # Exact match after ignoring numbers
        em_result = evaluation_result._calculate_em(output_key="answers", ignore_numbers=True)
        assert em_result["exact_match"] == 1.0

    def test_exact_match_regex_ignore(self):
        """
        Test exact match with ignoring specific regex patterns
        """
        predictions = ["Open123Source", "HaystackAI", "LLMs456"]
        labels = ["OpenSource", "HaystackAI", "LLMs"]

        evaluation_result = self.create_evaluation_result(predictions, labels)
        # Ignore numeric patterns
        regex_to_ignore = [r"\d+"]
        em_result = evaluation_result._calculate_em(output_key="answers", regexes_to_ignore=regex_to_ignore)

        assert em_result["exact_match"] == 1.0

    def test_exact_match_multiple_ignore_regex(self):
        """
        Test exact match with multiple ignoring parameters
        """
        predictions = ["Open123!Source", "Haystack.AI", "LLMs456,"]
        labels = ["OpenSource", "HaystackAI", "LLMs"]

        evaluation_result = self.create_evaluation_result(predictions, labels)
        # Ignore numeric patterns and punctuation using regex
        regex_to_ignore = [r"\d+", r"\W+"]
        em_result = evaluation_result._calculate_em(output_key="answers", regexes_to_ignore=regex_to_ignore)

        assert em_result["exact_match"] == 1.0

    def test_exact_match_multiple_ignore_combination(self):
        """
        Test exact match with multiple ignoring parameters combined
        """
        predictions = ["Open%123!$Source", "Haystack.AI##", "^^LLMs456,"]
        labels = ["OpenSource", "HaystackAI", "LLMs"]

        evaluation_result = self.create_evaluation_result(predictions, labels)
        # Ignore only special characters using regex
        regex_to_ignore = [r"[^\w\s\d]+"]
        em_result = evaluation_result._calculate_em(
            output_key="answers",
            ignore_numbers=True,
            ignore_punctuation=True,
            ignore_case=True,
            regexes_to_ignore=regex_to_ignore,
        )

        assert em_result["exact_match"] == 1.0
test/evaluation/test_eval_utils.py (new file, 245 lines):

from haystack.dataclasses import GeneratedAnswer
from haystack.evaluation.eval_utils import get_answers_from_output, preprocess_text


class TestEvalUtils:
    def test_extract_answers_from_pipeline_output(self):
        """
        Test that the function correctly extracts answers from the output of a pipeline.
        """
        outputs = [
            {
                "answer_builder": {
                    "answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})]
                }
            },
            {
                "answer_builder": {
                    "answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})]
                }
            },
            {
                "answer_builder": {
                    "answers": [GeneratedAnswer(data="Giorgio", query="Who lives in Rome?", documents=[], meta={})]
                }
            },
        ]

        runnable_type = "pipeline"
        output_key = "answers"
        expected_answers = ["Jean", "Mark", "Giorgio"]

        assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers

    def test_extract_answers_from_component_output(self):
        """
        Test that the function correctly extracts answers from the output of a component.
        """
        outputs = [
            {"answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})]},
            {"answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})]},
            {"answers": [GeneratedAnswer(data="Giorgio", query="Who lives in Rome?", documents=[], meta={})]},
        ]
        runnable_type = "component"
        output_key = "answers"
        expected_answers = ["Jean", "Mark", "Giorgio"]

        assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers

    def test_ignore_other_output_keys(self):
        """
        Test that the function only extracts answers and ignores other output keys.
        """
        outputs = [
            {
                "llm": {"replies": ["llm_reply_1"]},
                "answer_builder": {
                    "answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})]
                },
            },
            {
                "llm": {"replies": ["llm_reply_2"]},
                "answer_builder": {
                    "answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})]
                },
            },
            {
                "llm": {"replies": ["llm_reply_3"]},
                "answer_builder": {
                    "answers": [GeneratedAnswer(data="Giorgio", query="Who lives in Rome?", documents=[], meta={})]
                },
            },
        ]

        runnable_type = "pipeline"
        output_key = "answers"
        expected_answers = ["Jean", "Mark", "Giorgio"]

        assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers

    def test_handle_empty_outputs(self):
        """
        Test that the function correctly handles empty outputs.
        """
        outputs = []
        runnable_type = "pipeline"
        output_key = "answers"
        expected_answers = []

        assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers

    def test_handle_missing_keys(self):
        """
        Test that the function correctly handles outputs with missing keys.
        """
        outputs = [
            {
                "llm": {"replies": ["llm_reply_1"]},
                "answer_builder": {
                    "answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})]
                },
            },
            {
                "answer_builder": {
                    "answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})]
                }
            },
        ]

        runnable_type = "pipeline"
        output_key = "answers"
        expected_answers = ["Jean", "Mark"]

        assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers

    def test_handle_missing_values(self):
        """
        Test that the function correctly handles outputs with missing values.
        """
        outputs = [
            {"answer_builder": {"answers": []}},
            {
                "answer_builder": {
                    "answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})]
                }
            },
        ]
        runnable_type = "pipeline"
        output_key = "answers"
        expected_answers = ["Mark"]

        assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers

    def test_preprocess_text_default_parameters(self):
        """
        Test preprocess_text with default parameters.
        There should be no changes to the input text.
        """
        texts = ["Test, Output-1!", "Test, Output-2!"]
        expected_output = ["Test, Output-1!", "Test, Output-2!"]
        actual_output = preprocess_text(texts)

        assert actual_output == expected_output

    def test_preprocess_text_ignore_case(self):
        """
        Test preprocess_text with ignore_case=True.
        """
        texts = ["Test, Output-1!"]
        expected_output = ["test, output-1!"]

        actual_output = preprocess_text(texts, ignore_case=True)

        assert actual_output == expected_output

    def test_preprocess_text_ignore_punctuation(self):
        """
        Test preprocess_text with ignore_punctuation=True.
        """
        texts = ["Test, Output-1!"]
        expected_output = ["Test Output1"]

        actual_output = preprocess_text(texts, ignore_punctuation=True)

        assert actual_output == expected_output

    # Preprocess text with ignore_numbers=True.
    def test_preprocess_text_ignore_numbers(self):
        """
        Test preprocess_text with ignore_numbers=True. It should be able to remove numbers from the input.
        """
        texts = ["Test, Output-1!"]
        expected_output = ["Test, Output-!"]

        actual_output = preprocess_text(texts, ignore_numbers=True)

        assert actual_output == expected_output

    def test_preprocess_text_regexes_to_ignore(self):
        """
        Test preprocess_text with a list of regex patterns to ignore.
        """
        texts = ["Test, Output-1!"]
        expected_output = ["Test Output"]

        # Use regex patterns to remove digits and non-alphanumeric characters
        actual_output = preprocess_text(texts, regexes_to_ignore=[r"\d", r"[^\w\s]"])

        assert actual_output == expected_output

    def test_preprocess_text_empty_list(self):
        """
        Test preprocess_text with empty list of texts.
        """
        texts = []
        expected_output = []

        actual_output = preprocess_text(texts)

        assert actual_output == expected_output

    def test_preprocess_text_all_ignore_parameters(self):
        """
        Test preprocess_text with all ignore parameters set to True.
        """
        texts = ["Test, Output-1!"]
        expected_output = ["test output"]

        actual_output = preprocess_text(texts, ignore_case=True, ignore_punctuation=True, ignore_numbers=True)

        assert actual_output == expected_output

    def test_preprocess_text_regexes_to_ignore_empty_string(self):
        """
        Test preprocess_text with regexes_to_ignore=[""].
        """
        texts = ["Test, Output-1!"]
        expected_output = ["Test, Output-1!"]

        actual_output = preprocess_text(texts, regexes_to_ignore=[""])

        assert actual_output == expected_output

    # Preprocess text with regexes_to_ignore=[".*"].
    def test_preprocess_text_regexes_to_ignore_dot_star(self):
        """
        Test preprocess_text with regexes_to_ignore=[".*"].
        """
        texts = ["Test, Output-1!"]
        expected_output = [""]

        actual_output = preprocess_text(texts, regexes_to_ignore=[".*"])

        assert actual_output == expected_output

    def test_preprocess_text_regexes_to_ignore_same_substring(self):
        """
        Test preprocess_text with regexes_to_ignore where all the regex patterns match the same substring.
        """
        texts = ["Test, Output-1!"]
        expected_output = ["Test, Output-!"]

        actual_output = preprocess_text(texts, regexes_to_ignore=[r"\d", r"\d"])

        assert actual_output == expected_output