mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-17 10:58:51 +00:00
feat: Add Exact Match metric (#6696)
* Add exact match metric
* Add release notes
* Cleanup comments in test_eval_exact_match.py
* Create separate preprocessing function; Add output_key parameter
* Update release note

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Co-authored-by: Julian Risch <julian.risch@deepset.ai>
This commit is contained in:
parent 8a08ab52e1
commit a238c6dd51
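For orientation before reading the diff: the new Exact Match support is driven through `EvaluationResult.calculate_metrics`. Below is a minimal sketch of the call pattern the changed tests rely on; `eval_result` stands in for an `EvaluationResult` returned by `haystack.evaluation.eval.eval(...)`, and the output path is a placeholder.

```python
from pathlib import Path

from haystack.evaluation.metrics import Metric

# `eval_result` is assumed to be an EvaluationResult produced by
# haystack.evaluation.eval.eval(pipeline, inputs=..., expected_outputs=...).
# Default comparison: answers under the "answers" output key, verbatim.
metrics = eval_result.calculate_metrics(Metric.EM, output_key="answers")

# Optional preprocessing flags introduced by this commit.
metrics_relaxed = eval_result.calculate_metrics(
    Metric.EM,
    output_key="answers",
    ignore_case=True,
    ignore_punctuation=True,
    ignore_numbers=True,
)

print(metrics["exact_match"])  # fraction of exact matches, e.g. 1.0
metrics.save(Path("exact_match_score.json"))  # MetricsResult serializes to JSON, as in the tests
```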
@@ -35,11 +35,7 @@ def test_extractive_qa_pipeline(tmp_path):
query="Who lives in Paris?",
score=0.7713339924812317,
data="Jean and I",
document=Document(
id="6c90b78ad94e4e634e2a067b5fe2d26d4ce95405ec222cbaefaeb09ab4dce81e",
content="My name is Jean and I live in Paris.",
score=0.33144005810482535,
),
document=Document(content="My name is Jean and I live in Paris.", score=0.33144005810482535),
context=None,
document_offset=ExtractedAnswer.Span(start=11, end=21),
context_offset=None,
@@ -65,11 +61,7 @@ def test_extractive_qa_pipeline(tmp_path):
query="Who lives in Berlin?",
score=0.7047999501228333,
data="Mark and I",
document=Document(
id="10a183e965c2e107e20507c717f16559c58a8ba4bc7c577ea8dc32a8d6ca7a20",
content="My name is Mark and I live in Berlin.",
score=0.33144005810482535,
),
document=Document(content="My name is Mark and I live in Berlin.", score=0.33144005810482535),
context=None,
document_offset=ExtractedAnswer.Span(start=11, end=21),
context_offset=None,
@@ -95,11 +87,7 @@ def test_extractive_qa_pipeline(tmp_path):
query="Who lives in Rome?",
score=0.7661304473876953,
data="Giorgio and I",
document=Document(
id="fb0f1efe94b3c78aa1c4e5a17a5ef8270f70e89d36a3665c8362675e8a769a27",
content="My name is Giorgio and I live in Rome.",
score=0.33144005810482535,
),
document=Document(content="My name is Giorgio and I live in Rome.", score=0.33144005810482535),
context=None,
document_offset=ExtractedAnswer.Span(start=11, end=24),
context_offset=None,
@@ -127,10 +115,14 @@ def test_extractive_qa_pipeline(tmp_path):
assert len(eval_result.outputs) == len(expected_outputs) == len(inputs)
assert eval_result.runnable.to_dict() == qa_pipeline.to_dict()

metrics = eval_result.calculate_metrics(Metric.EM)
metrics_default = eval_result.calculate_metrics(Metric.EM, output_key="answers")
metrics_custom_parameters = eval_result.calculate_metrics(
Metric.EM, output_key="answers", ignore_case=True, ignore_punctuation=True, ignore_numbers=True
)
# Save metric results to json
metrics.save(tmp_path / "exact_match_score.json")
metrics_default.save(tmp_path / "exact_match_score.json")

assert metrics["exact_match"] == 1.0
assert metrics_default["exact_match"] == 1.0
assert metrics_custom_parameters["exact_match"] == 1.0
with open(tmp_path / "exact_match_score.json", "r") as f:
assert metrics == json.load(f)
assert metrics_default == json.load(f)
@@ -7,7 +7,7 @@ from haystack.components.embedders import SentenceTransformersDocumentEmbedder,
from haystack.components.generators import HuggingFaceLocalGenerator
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.components.writers import DocumentWriter
from haystack.dataclasses import Document
from haystack.dataclasses import Document, GeneratedAnswer
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.evaluation.eval import eval
from haystack.evaluation.metrics import Metric
@@ -59,9 +59,54 @@ def test_bm25_rag_pipeline(tmp_path):
]

expected_outputs = [
{"llm": {"replies": ["Jean"]}},
{"llm": {"replies": ["Mark"]}},
{"llm": {"replies": ["Giorgio"]}},
{
"answer_builder": {
"answers": [
GeneratedAnswer(
data="Jean",
query="Who lives in Paris?",
documents=[
Document(content="My name is Jean and I live in Paris.", score=0.33144005810482535),
Document(content="My name is Giorgio and I live in Rome.", score=-0.17938556566116537),
Document(content="My name is Mark and I live in Berlin.", score=-0.17938556566116537),
],
meta={},
)
]
}
},
{
"answer_builder": {
"answers": [
GeneratedAnswer(
data="Mark",
query="Who lives in Berlin?",
documents=[
Document(content="My name is Mark and I live in Berlin.", score=0.33144005810482535),
Document(content="My name is Giorgio and I live in Rome.", score=-0.17938556566116537),
Document(content="My name is Jean and I live in Paris.", score=-0.17938556566116537),
],
meta={},
)
]
}
},
{
"answer_builder": {
"answers": [
GeneratedAnswer(
data="Giorgio",
query="Who lives in Rome?",
documents=[
Document(content="My name is Giorgio and I live in Rome.", score=0.33144005810482535),
Document(content="My name is Mark and I live in Berlin.", score=-0.17938556566116537),
Document(content="My name is Jean and I live in Paris.", score=-0.17938556566116537),
],
meta={},
)
]
}
},
]

eval_result = eval(rag_pipeline, inputs=inputs, expected_outputs=expected_outputs)
@@ -71,13 +116,17 @@ def test_bm25_rag_pipeline(tmp_path):
assert len(eval_result.outputs) == len(expected_outputs) == len(inputs)
assert eval_result.runnable.to_dict() == rag_pipeline.to_dict()

metrics = eval_result.calculate_metrics(Metric.EM)
metrics_default = eval_result.calculate_metrics(Metric.EM, output_key="answers")
metrics_custom_parameters = eval_result.calculate_metrics(
Metric.EM, output_key="answers", ignore_case=True, ignore_punctuation=True, ignore_numbers=True
)
# Save metric results to json
metrics.save(tmp_path / "exact_match_score.json")
metrics_default.save(tmp_path / "exact_match_score.json")

assert metrics["exact_match"] == 1.0
assert metrics_default["exact_match"] == 1.0
assert metrics_custom_parameters["exact_match"] == 1.0
with open(tmp_path / "exact_match_score.json", "r") as f:
assert metrics == json.load(f)
assert metrics_default == json.load(f)


def test_embedding_retrieval_rag_pipeline(tmp_path):
@@ -142,9 +191,54 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
]

expected_outputs = [
{"llm": {"replies": ["Jean"]}},
{"llm": {"replies": ["Mark"]}},
{"llm": {"replies": ["Giorgio"]}},
{
"answer_builder": {
"answers": [
GeneratedAnswer(
data="Jean",
query="Who lives in Paris?",
documents=[
Document(content="My name is Jean and I live in Paris.", score=0.33144005810482535),
Document(content="My name is Giorgio and I live in Rome.", score=-0.17938556566116537),
Document(content="My name is Mark and I live in Berlin.", score=-0.17938556566116537),
],
meta={},
)
]
}
},
{
"answer_builder": {
"answers": [
GeneratedAnswer(
data="Mark",
query="Who lives in Berlin?",
documents=[
Document(content="My name is Mark and I live in Berlin.", score=0.33144005810482535),
Document(content="My name is Giorgio and I live in Rome.", score=-0.17938556566116537),
Document(content="My name is Jean and I live in Paris.", score=-0.17938556566116537),
],
meta={},
)
]
}
},
{
"answer_builder": {
"answers": [
GeneratedAnswer(
data="Giorgio",
query="Who lives in Rome?",
documents=[
Document(content="My name is Giorgio and I live in Rome.", score=0.33144005810482535),
Document(content="My name is Mark and I live in Berlin.", score=-0.17938556566116537),
Document(content="My name is Jean and I live in Paris.", score=-0.17938556566116537),
],
meta={},
)
]
}
},
]

eval_result = eval(rag_pipeline, inputs=inputs, expected_outputs=expected_outputs)
@@ -154,10 +248,14 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
assert len(eval_result.outputs) == len(expected_outputs) == len(inputs)
assert eval_result.runnable.to_dict() == rag_pipeline.to_dict()

metrics = eval_result.calculate_metrics(Metric.EM)
metrics_default = eval_result.calculate_metrics(Metric.EM, output_key="answers")
metrics_custom_parameters = eval_result.calculate_metrics(
Metric.EM, output_key="answers", ignore_case=True, ignore_punctuation=True, ignore_numbers=True
)
# Save metric results to json
metrics.save(tmp_path / "exact_match_score.json")
metrics_default.save(tmp_path / "exact_match_score.json")

assert metrics["exact_match"] == 1.0
assert metrics_default["exact_match"] == 1.0
assert metrics_custom_parameters["exact_match"] == 1.0
with open(tmp_path / "exact_match_score.json", "r") as f:
assert metrics == json.load(f)
assert metrics_default == json.load(f)
@@ -1,7 +1,10 @@
from typing import Any, Callable, Dict, List, Union

import numpy as np

from haystack import Pipeline
from haystack.core.component import Component
from haystack.evaluation.eval_utils import get_answers_from_output, preprocess_text
from haystack.evaluation.metrics import Metric, MetricsResult


@@ -29,9 +32,15 @@ class EvaluationResult:
self.outputs = outputs
self.expected_outputs = expected_outputs

# Determine the type of the runnable
if str(type(runnable).__name__) == "Pipeline":
self.runnable_type = "pipeline"
else:
self.runnable_type = "component"

# Mapping of metrics to their corresponding functions.
# This should be kept in sync with the Metric enum
self._supported_metrics = {
self._supported_metrics: Dict[Metric, Callable[..., MetricsResult]] = {
Metric.RECALL: self._calculate_recall,
Metric.MRR: self._calculate_mrr,
Metric.MAP: self._calculate_map,
@@ -65,8 +74,45 @@ class EvaluationResult:
def _calculate_f1(self):
return MetricsResult({"f1": None})

def _calculate_em(self):
return MetricsResult({"exact_match": 1.0})
def _calculate_em(
self, output_key: str, regexes_to_ignore=None, ignore_case=False, ignore_punctuation=False, ignore_numbers=False
) -> MetricsResult:
"""
Calculates the Exact Match (EM) score between two lists of predictions and labels.
Exact Match (EM) score measures the percentage of samples where the predicted text exactly matches the
corresponding ground truth label.

:param output_key: The key of the output to use for comparison.
:param regexes_to_ignore (list, optional): A list of regular expressions. If provided, it removes substrings
matching these regular expressions from both predictions and labels before comparison. Defaults to None.
:param ignore_case (bool, optional): If True, performs case-insensitive comparison. Defaults to False.
:param ignore_punctuation (bool, optional): If True, removes punctuation from both predictions and labels before
comparison. Defaults to False.
:param ignore_numbers (bool, optional): If True, removes numerical digits from both predictions and labels
before comparison. Defaults to False.
:return: A MetricsResult object containing the calculated Exact Match (EM) score.
"""

predictions = get_answers_from_output(
outputs=self.outputs, output_key=output_key, runnable_type=self.runnable_type
)
labels = get_answers_from_output(
outputs=self.expected_outputs, output_key=output_key, runnable_type=self.runnable_type
)

if len(predictions) != len(labels):
raise ValueError("The number of predictions and labels must be the same.")
if len(predictions) == len(labels) == 0:
# Return Exact Match as 0 for no inputs
return MetricsResult({"exact_match": 0.0})

predictions = preprocess_text(predictions, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers)
labels = preprocess_text(labels, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers)

score_list = np.array(predictions) == np.array(labels)
exact_match_score = np.mean(score_list)

return MetricsResult({"exact_match": exact_match_score})

def _calculate_sas(self):
return MetricsResult({"exact_match": None})
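The core of `_calculate_em` above is a preprocessing pass followed by an element-wise equality check averaged over all samples. A quick illustrative sketch of that arithmetic, reusing the same helpers; the prediction and label values here are made up for illustration only.

```python
import numpy as np

from haystack.evaluation.eval_utils import preprocess_text

# Two of the three hypothetical entries match once case and punctuation are ignored.
predictions = preprocess_text(["Jean", "MARK!", "Rome"], ignore_case=True, ignore_punctuation=True)
labels = preprocess_text(["jean", "mark", "Giorgio"], ignore_case=True, ignore_punctuation=True)

score_list = np.array(predictions) == np.array(labels)  # [True, True, False]
print(np.mean(score_list))  # 2/3 ~ 0.667, reported as {"exact_match": 0.666...}
```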

65 haystack/evaluation/eval_utils.py Normal file
@@ -0,0 +1,65 @@
import re
import string
from typing import Any, Dict, List


def preprocess_text(
texts: List[str], regexes_to_ignore=None, ignore_case=False, ignore_punctuation=False, ignore_numbers=False
) -> List[str]:
"""
Preprocess the outputs of the runnable to remove unwanted characters.

:param regexes_to_ignore (list, optional): A list of regular expressions. If provided, it removes substrings
matching these regular expressions from the text. Defaults to None.
:param ignore_case (bool, optional): If True, converts all characters to lowercase. Defaults to False.
:param ignore_punctuation (bool, optional): If True, removes punctuation from the text. Defaults to False.
:param ignore_numbers (bool, optional): If True, removes numerical digits from the text. Defaults to False.
:return: A list of preprocessed strings.
"""
if regexes_to_ignore:
combined_regex = "|".join(regexes_to_ignore)
texts = [re.sub(combined_regex, "", text, flags=re.IGNORECASE) for text in texts]

if ignore_case:
texts = [text.lower() for text in texts]

if ignore_punctuation:
translator = str.maketrans("", "", string.punctuation)
texts = [text.translate(translator) for text in texts]

if ignore_numbers:
translator = str.maketrans("", "", string.digits)
texts = [text.translate(translator) for text in texts]

return texts


def get_answers_from_output(outputs: List[Dict[str, Any]], output_key: str, runnable_type: str) -> List[str]:
"""
Extracts the answers from the output of a pipeline or component.

:param outputs: The outputs of the runnable.
:return: List of answers from the runnable output.
"""
answers = []
if runnable_type == "pipeline":
# Iterate over output from each Pipeline run
for output in outputs:
# Iterate over output of component in each Pipeline run
for component_output in output.values():
# Only extract answers based on key
for key in component_output.keys():
if output_key in key:
for generated_answer in component_output[output_key]:
if generated_answer.data:
answers.append(generated_answer.data)
else:
# Iterate over output from each Component run
for output in outputs:
# Only extract answers based on key
for key in output.keys():
if output_key in key:
for generated_answer in output[output_key]:
if generated_answer.data:
answers.append(generated_answer.data)
return answers
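As a quick sanity check on the two helpers above (the inputs are invented for illustration): `preprocess_text` applies the cleanups in the order written, regexes, case, punctuation, then digits, and `get_answers_from_output` pulls `GeneratedAnswer.data` values out of pipeline-style output dicts keyed by `output_key`.

```python
from haystack.dataclasses import GeneratedAnswer
from haystack.evaluation.eval_utils import get_answers_from_output, preprocess_text

print(preprocess_text(["OpenSource!", "Haystack 2.0"], ignore_case=True, ignore_punctuation=True))
# -> ['opensource', 'haystack 20']

outputs = [
    {"answer_builder": {"answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})]}}
]
print(get_answers_from_output(outputs, output_key="answers", runnable_type="pipeline"))
# -> ['Jean']
```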

8 releasenotes/notes/add-exact-match-a7df21717238b771.yaml Normal file
@@ -0,0 +1,8 @@
---
features:
- |
Adds support for the Exact Match metric to `EvaluationResult.calculate_metrics(...)`:
```python
from haystack.evaluation.metrics import Metric
exact_match_metric = eval_result.calculate_metrics(Metric.EM, output_key="answers")
```

0 test/evaluation/__init__.py Normal file

167 test/evaluation/test_eval_exact_match.py Normal file
@@ -0,0 +1,167 @@
from haystack import Pipeline
from haystack.dataclasses import GeneratedAnswer
from haystack.evaluation.eval import EvaluationResult


class TestExactMatch:
def create_evaluation_result(self, predictions, labels):
"""
Creates an evaluation result of a RAG pipeline using the list of predictions and labels for testing the exact match.
"""
runnable = Pipeline()
inputs = []
outputs = [
{"answer_builder": {"answers": [GeneratedAnswer(data=pred, query="", documents=[], meta={})]}}
for pred in predictions
]
expected_outputs = [
{"answer_builder": {"answers": [GeneratedAnswer(data=label, query="", documents=[], meta={})]}}
for label in labels
]
evaluation_result = EvaluationResult(runnable, inputs, outputs, expected_outputs)
return evaluation_result

def test_exact_match_empty_inputs(self):
"""
Test exact match with empty inputs
"""
runnable = Pipeline()
inputs = []
outputs = [
{"answer_builder": {"answers": []}},
{"answer_builder": {"answers": []}},
{"answer_builder": {"answers": []}},
]
expected_outputs = [
{"answer_builder": {"answers": []}},
{"answer_builder": {"answers": []}},
{"answer_builder": {"answers": []}},
]
evaluation_result = EvaluationResult(runnable, inputs, outputs, expected_outputs)
# Expecting 0% exact match for empty inputs
em_result = evaluation_result._calculate_em(output_key="answers")

assert em_result["exact_match"] == 0.0

def test_exact_match_same_inputs(self):
"""
Test exact match with default parameters
"""
predictions = ["OpenSource", "HaystackAI", "LLMs"]
labels = ["OpenSource", "HaystackAI", "LLMs"]
evaluation_result = self.create_evaluation_result(predictions, labels)
em_result = evaluation_result._calculate_em(output_key="answers")

assert em_result["exact_match"] == 1.0

def test_exact_match_single_word(self):
"""
Test exact match with single-word inputs
"""
predictions = ["OpenSource"]
labels = ["OpenSource"]

evaluation_result = self.create_evaluation_result(predictions, labels)
em_result = evaluation_result._calculate_em(output_key="answers")

assert em_result["exact_match"] == 1.0

def test_exact_match_negative_case(self):
"""
Test exact match with deliberately mismatched predictions and labels
"""
predictions = ["OpenSource", "HaystackAI", "LLMs"]
labels = ["Source", "HaystackAI", "LLMs"]

evaluation_result = self.create_evaluation_result(predictions, labels)
# Expecting EM to be 2/3 as 2 out of 3 items match
expected_em = 2 / 3
em_result = evaluation_result._calculate_em(output_key="answers")

assert em_result["exact_match"] == expected_em

def test_exact_match_ignore_case(self):
"""
Test exact match with ignoring case sensitivity
"""
predictions = ["OpenSource", "HaystackAI", "LLMs"]
labels = ["opensource", "HAYSTACKAI", "llMs"]

evaluation_result = self.create_evaluation_result(predictions, labels)
# Exact match after case ignoring
em_result = evaluation_result._calculate_em(output_key="answers", ignore_case=True)

assert em_result["exact_match"] == 1.0

def test_exact_match_ignore_punctuation(self):
"""
Test exact match with ignoring punctuation
"""
predictions = ["OpenSource!", "Haystack.AI", "LLMs,"]
labels = ["OpenSource", "HaystackAI", "LLMs"]

evaluation_result = self.create_evaluation_result(predictions, labels)
# Exact match after ignoring punctuation
em_result = evaluation_result._calculate_em(output_key="answers", ignore_punctuation=True)

assert em_result["exact_match"] == 1.0

def test_exact_match_ignore_numbers(self):
"""
Test exact match with ignoring numbers
"""
predictions = ["OpenSource123", "HaystackAI", "LLMs456"]
labels = ["OpenSource", "HaystackAI", "LLMs"]

evaluation_result = self.create_evaluation_result(predictions, labels)
# Exact match after ignoring numbers
em_result = evaluation_result._calculate_em(output_key="answers", ignore_numbers=True)
assert em_result["exact_match"] == 1.0

def test_exact_match_regex_ignore(self):
"""
Test exact match with ignoring specific regex patterns
"""
predictions = ["Open123Source", "HaystackAI", "LLMs456"]
labels = ["OpenSource", "HaystackAI", "LLMs"]

evaluation_result = self.create_evaluation_result(predictions, labels)
# Ignore numeric patterns
regex_to_ignore = [r"\d+"]
em_result = evaluation_result._calculate_em(output_key="answers", regexes_to_ignore=regex_to_ignore)

assert em_result["exact_match"] == 1.0

def test_exact_match_multiple_ignore_regex(self):
"""
Test exact match with multiple ignoring parameters
"""
predictions = ["Open123!Source", "Haystack.AI", "LLMs456,"]
labels = ["OpenSource", "HaystackAI", "LLMs"]

evaluation_result = self.create_evaluation_result(predictions, labels)
# Ignore numeric patterns and punctuation using regex
regex_to_ignore = [r"\d+", r"\W+"]
em_result = evaluation_result._calculate_em(output_key="answers", regexes_to_ignore=regex_to_ignore)

assert em_result["exact_match"] == 1.0

def test_exact_match_multiple_ignore_combination(self):
"""
Test exact match with multiple ignoring parameters combined
"""
predictions = ["Open%123!$Source", "Haystack.AI##", "^^LLMs456,"]
labels = ["OpenSource", "HaystackAI", "LLMs"]

evaluation_result = self.create_evaluation_result(predictions, labels)
# Ignore only special characters using regex
regex_to_ignore = [r"[^\w\s\d]+"]
em_result = evaluation_result._calculate_em(
output_key="answers",
ignore_numbers=True,
ignore_punctuation=True,
ignore_case=True,
regexes_to_ignore=regex_to_ignore,
)

assert em_result["exact_match"] == 1.0

245 test/evaluation/test_eval_utils.py Normal file
@@ -0,0 +1,245 @@
from haystack.dataclasses import GeneratedAnswer
from haystack.evaluation.eval_utils import get_answers_from_output, preprocess_text


class TestEvalUtils:
def test_extract_answers_from_pipeline_output(self):
"""
Test that the function correctly extracts answers from the output of a pipeline.
"""
outputs = [
{
"answer_builder": {
"answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})]
}
},
{
"answer_builder": {
"answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})]
}
},
{
"answer_builder": {
"answers": [GeneratedAnswer(data="Giorgio", query="Who lives in Rome?", documents=[], meta={})]
}
},
]

runnable_type = "pipeline"
output_key = "answers"
expected_answers = ["Jean", "Mark", "Giorgio"]

assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers

def test_extract_answers_from_component_output(self):
"""
Test that the function correctly extracts answers from the output of a component.
"""
outputs = [
{"answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})]},
{"answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})]},
{"answers": [GeneratedAnswer(data="Giorgio", query="Who lives in Rome?", documents=[], meta={})]},
]
runnable_type = "component"
output_key = "answers"
expected_answers = ["Jean", "Mark", "Giorgio"]

assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers

def test_ignore_other_output_keys(self):
"""
Test that the function only extracts answers and ignores other output keys.
"""
outputs = [
{
"llm": {"replies": ["llm_reply_1"]},
"answer_builder": {
"answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})]
},
},
{
"llm": {"replies": ["llm_reply_2"]},
"answer_builder": {
"answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})]
},
},
{
"llm": {"replies": ["llm_reply_3"]},
"answer_builder": {
"answers": [GeneratedAnswer(data="Giorgio", query="Who lives in Rome?", documents=[], meta={})]
},
},
]

runnable_type = "pipeline"
output_key = "answers"
expected_answers = ["Jean", "Mark", "Giorgio"]

assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers

def test_handle_empty_outputs(self):
"""
Test that the function correctly handles empty outputs.
"""
outputs = []
runnable_type = "pipeline"
output_key = "answers"
expected_answers = []

assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers

def test_handle_missing_keys(self):
"""
Test that the function correctly handles outputs with missing keys.
"""
outputs = [
{
"llm": {"replies": ["llm_reply_1"]},
"answer_builder": {
"answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})]
},
},
{
"answer_builder": {
"answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})]
}
},
]

runnable_type = "pipeline"
output_key = "answers"
expected_answers = ["Jean", "Mark"]

assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers

def test_handle_missing_values(self):
"""
Test that the function correctly handles outputs with missing values.
"""
outputs = [
{"answer_builder": {"answers": []}},
{
"answer_builder": {
"answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})]
}
},
]
runnable_type = "pipeline"
output_key = "answers"
expected_answers = ["Mark"]

assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers

def test_preprocess_text_default_parameters(self):
"""
Test preprocess_text with default parameters.
There should be no changes to the input text.
"""
texts = ["Test, Output-1!", "Test, Output-2!"]
expected_output = ["Test, Output-1!", "Test, Output-2!"]
actual_output = preprocess_text(texts)

assert actual_output == expected_output

def test_preprocess_text_ignore_case(self):
"""
Test preprocess_text with ignore_case=True.

"""
texts = ["Test, Output-1!"]
expected_output = ["test, output-1!"]

actual_output = preprocess_text(texts, ignore_case=True)

assert actual_output == expected_output

def test_preprocess_text_ignore_punctuation(self):
"""
Test preprocess_text with ignore_punctuation=True.
"""
texts = ["Test, Output-1!"]
expected_output = ["Test Output1"]

actual_output = preprocess_text(texts, ignore_punctuation=True)

assert actual_output == expected_output

# Preprocess text with ignore_numbers=True.
def test_preprocess_text_ignore_numbers(self):
"""
Test preprocess_text with ignore_numbers=True. It should be able to remove numbers from the input.
"""
texts = ["Test, Output-1!"]
expected_output = ["Test, Output-!"]

actual_output = preprocess_text(texts, ignore_numbers=True)

assert actual_output == expected_output

def test_preprocess_text_regexes_to_ignore(self):
"""
Test preprocess_text with a list of regex patterns to ignore.
"""
texts = ["Test, Output-1!"]
expected_output = ["Test Output"]

# Use regex patterns to remove digits and non-alphanumeric characters
actual_output = preprocess_text(texts, regexes_to_ignore=[r"\d", r"[^\w\s]"])

assert actual_output == expected_output

def test_preprocess_text_empty_list(self):
"""
Test preprocess_text with empty list of texts.
"""
texts = []
expected_output = []

actual_output = preprocess_text(texts)

assert actual_output == expected_output

def test_preprocess_text_all_ignore_parameters(self):
"""
Test preprocess_text with all ignore parameters set to True.
"""
texts = ["Test, Output-1!"]
expected_output = ["test output"]

actual_output = preprocess_text(texts, ignore_case=True, ignore_punctuation=True, ignore_numbers=True)

assert actual_output == expected_output

def test_preprocess_text_regexes_to_ignore_empty_string(self):
"""
Test preprocess_text with regexes_to_ignore=[""].
"""
texts = ["Test, Output-1!"]
expected_output = ["Test, Output-1!"]

actual_output = preprocess_text(texts, regexes_to_ignore=[""])

assert actual_output == expected_output

# Preprocess text with regexes_to_ignore=[".*"].
def test_preprocess_text_regexes_to_ignore_dot_star(self):
"""
Test preprocess_text with regexes_to_ignore=[".*"].
"""
texts = ["Test, Output-1!"]
expected_output = [""]

actual_output = preprocess_text(texts, regexes_to_ignore=[".*"])

assert actual_output == expected_output

def test_preprocess_text_regexes_to_ignore_same_substring(self):
"""
Test preprocess_text with regexes_to_ignore where all the regex patterns match the same substring.
"""
texts = ["Test, Output-1!"]
expected_output = ["Test, Output-!"]

actual_output = preprocess_text(texts, regexes_to_ignore=[r"\d", r"\d"])

assert actual_output == expected_output