diff --git a/e2e/pipelines/test_eval_extractive_qa_pipeline.py b/e2e/pipelines/test_eval_extractive_qa_pipeline.py index d5f8fcf3d..57ec1b63d 100644 --- a/e2e/pipelines/test_eval_extractive_qa_pipeline.py +++ b/e2e/pipelines/test_eval_extractive_qa_pipeline.py @@ -35,11 +35,7 @@ def test_extractive_qa_pipeline(tmp_path): query="Who lives in Paris?", score=0.7713339924812317, data="Jean and I", - document=Document( - id="6c90b78ad94e4e634e2a067b5fe2d26d4ce95405ec222cbaefaeb09ab4dce81e", - content="My name is Jean and I live in Paris.", - score=0.33144005810482535, - ), + document=Document(content="My name is Jean and I live in Paris.", score=0.33144005810482535), context=None, document_offset=ExtractedAnswer.Span(start=11, end=21), context_offset=None, @@ -65,11 +61,7 @@ def test_extractive_qa_pipeline(tmp_path): query="Who lives in Berlin?", score=0.7047999501228333, data="Mark and I", - document=Document( - id="10a183e965c2e107e20507c717f16559c58a8ba4bc7c577ea8dc32a8d6ca7a20", - content="My name is Mark and I live in Berlin.", - score=0.33144005810482535, - ), + document=Document(content="My name is Mark and I live in Berlin.", score=0.33144005810482535), context=None, document_offset=ExtractedAnswer.Span(start=11, end=21), context_offset=None, @@ -95,11 +87,7 @@ def test_extractive_qa_pipeline(tmp_path): query="Who lives in Rome?", score=0.7661304473876953, data="Giorgio and I", - document=Document( - id="fb0f1efe94b3c78aa1c4e5a17a5ef8270f70e89d36a3665c8362675e8a769a27", - content="My name is Giorgio and I live in Rome.", - score=0.33144005810482535, - ), + document=Document(content="My name is Giorgio and I live in Rome.", score=0.33144005810482535), context=None, document_offset=ExtractedAnswer.Span(start=11, end=24), context_offset=None, @@ -127,10 +115,14 @@ def test_extractive_qa_pipeline(tmp_path): assert len(eval_result.outputs) == len(expected_outputs) == len(inputs) assert eval_result.runnable.to_dict() == qa_pipeline.to_dict() - metrics = eval_result.calculate_metrics(Metric.EM) + metrics_default = eval_result.calculate_metrics(Metric.EM, output_key="answers") + metrics_custom_parameters = eval_result.calculate_metrics( + Metric.EM, output_key="answers", ignore_case=True, ignore_punctuation=True, ignore_numbers=True + ) # Save metric results to json - metrics.save(tmp_path / "exact_match_score.json") + metrics_default.save(tmp_path / "exact_match_score.json") - assert metrics["exact_match"] == 1.0 + assert metrics_default["exact_match"] == 1.0 + assert metrics_custom_parameters["exact_match"] == 1.0 with open(tmp_path / "exact_match_score.json", "r") as f: - assert metrics == json.load(f) + assert metrics_default == json.load(f) diff --git a/e2e/pipelines/test_eval_rag_pipelines.py b/e2e/pipelines/test_eval_rag_pipelines.py index 5fa50be4e..34e7a888a 100644 --- a/e2e/pipelines/test_eval_rag_pipelines.py +++ b/e2e/pipelines/test_eval_rag_pipelines.py @@ -7,7 +7,7 @@ from haystack.components.embedders import SentenceTransformersDocumentEmbedder, from haystack.components.generators import HuggingFaceLocalGenerator from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever from haystack.components.writers import DocumentWriter -from haystack.dataclasses import Document +from haystack.dataclasses import Document, GeneratedAnswer from haystack.document_stores.in_memory import InMemoryDocumentStore from haystack.evaluation.eval import eval from haystack.evaluation.metrics import Metric @@ -59,9 +59,54 @@ def test_bm25_rag_pipeline(tmp_path): 
] expected_outputs = [ - {"llm": {"replies": ["Jean"]}}, - {"llm": {"replies": ["Mark"]}}, - {"llm": {"replies": ["Giorgio"]}}, + { + "answer_builder": { + "answers": [ + GeneratedAnswer( + data="Jean", + query="Who lives in Paris?", + documents=[ + Document(content="My name is Jean and I live in Paris.", score=0.33144005810482535), + Document(content="My name is Giorgio and I live in Rome.", score=-0.17938556566116537), + Document(content="My name is Mark and I live in Berlin.", score=-0.17938556566116537), + ], + meta={}, + ) + ] + } + }, + { + "answer_builder": { + "answers": [ + GeneratedAnswer( + data="Mark", + query="Who lives in Berlin?", + documents=[ + Document(content="My name is Mark and I live in Berlin.", score=0.33144005810482535), + Document(content="My name is Giorgio and I live in Rome.", score=-0.17938556566116537), + Document(content="My name is Jean and I live in Paris.", score=-0.17938556566116537), + ], + meta={}, + ) + ] + } + }, + { + "answer_builder": { + "answers": [ + GeneratedAnswer( + data="Giorgio", + query="Who lives in Rome?", + documents=[ + Document(content="My name is Giorgio and I live in Rome.", score=0.33144005810482535), + Document(content="My name is Mark and I live in Berlin.", score=-0.17938556566116537), + Document(content="My name is Jean and I live in Paris.", score=-0.17938556566116537), + ], + meta={}, + ) + ] + } + }, ] eval_result = eval(rag_pipeline, inputs=inputs, expected_outputs=expected_outputs) @@ -71,13 +116,17 @@ def test_bm25_rag_pipeline(tmp_path): assert len(eval_result.outputs) == len(expected_outputs) == len(inputs) assert eval_result.runnable.to_dict() == rag_pipeline.to_dict() - metrics = eval_result.calculate_metrics(Metric.EM) + metrics_default = eval_result.calculate_metrics(Metric.EM, output_key="answers") + metrics_custom_parameters = eval_result.calculate_metrics( + Metric.EM, output_key="answers", ignore_case=True, ignore_punctuation=True, ignore_numbers=True + ) # Save metric results to json - metrics.save(tmp_path / "exact_match_score.json") + metrics_default.save(tmp_path / "exact_match_score.json") - assert metrics["exact_match"] == 1.0 + assert metrics_default["exact_match"] == 1.0 + assert metrics_custom_parameters["exact_match"] == 1.0 with open(tmp_path / "exact_match_score.json", "r") as f: - assert metrics == json.load(f) + assert metrics_default == json.load(f) def test_embedding_retrieval_rag_pipeline(tmp_path): @@ -142,9 +191,54 @@ def test_embedding_retrieval_rag_pipeline(tmp_path): ] expected_outputs = [ - {"llm": {"replies": ["Jean"]}}, - {"llm": {"replies": ["Mark"]}}, - {"llm": {"replies": ["Giorgio"]}}, + { + "answer_builder": { + "answers": [ + GeneratedAnswer( + data="Jean", + query="Who lives in Paris?", + documents=[ + Document(content="My name is Jean and I live in Paris.", score=0.33144005810482535), + Document(content="My name is Giorgio and I live in Rome.", score=-0.17938556566116537), + Document(content="My name is Mark and I live in Berlin.", score=-0.17938556566116537), + ], + meta={}, + ) + ] + } + }, + { + "answer_builder": { + "answers": [ + GeneratedAnswer( + data="Mark", + query="Who lives in Berlin?", + documents=[ + Document(content="My name is Mark and I live in Berlin.", score=0.33144005810482535), + Document(content="My name is Giorgio and I live in Rome.", score=-0.17938556566116537), + Document(content="My name is Jean and I live in Paris.", score=-0.17938556566116537), + ], + meta={}, + ) + ] + } + }, + { + "answer_builder": { + "answers": [ + GeneratedAnswer( + 
data="Giorgio", + query="Who lives in Rome?", + documents=[ + Document(content="My name is Giorgio and I live in Rome.", score=0.33144005810482535), + Document(content="My name is Mark and I live in Berlin.", score=-0.17938556566116537), + Document(content="My name is Jean and I live in Paris.", score=-0.17938556566116537), + ], + meta={}, + ) + ] + } + }, ] eval_result = eval(rag_pipeline, inputs=inputs, expected_outputs=expected_outputs) @@ -154,10 +248,14 @@ def test_embedding_retrieval_rag_pipeline(tmp_path): assert len(eval_result.outputs) == len(expected_outputs) == len(inputs) assert eval_result.runnable.to_dict() == rag_pipeline.to_dict() - metrics = eval_result.calculate_metrics(Metric.EM) + metrics_default = eval_result.calculate_metrics(Metric.EM, output_key="answers") + metrics_custom_parameters = eval_result.calculate_metrics( + Metric.EM, output_key="answers", ignore_case=True, ignore_punctuation=True, ignore_numbers=True + ) # Save metric results to json - metrics.save(tmp_path / "exact_match_score.json") + metrics_default.save(tmp_path / "exact_match_score.json") - assert metrics["exact_match"] == 1.0 + assert metrics_default["exact_match"] == 1.0 + assert metrics_custom_parameters["exact_match"] == 1.0 with open(tmp_path / "exact_match_score.json", "r") as f: - assert metrics == json.load(f) + assert metrics_default == json.load(f) diff --git a/haystack/evaluation/eval.py b/haystack/evaluation/eval.py index 28b4b76bb..49d6a87e1 100644 --- a/haystack/evaluation/eval.py +++ b/haystack/evaluation/eval.py @@ -1,7 +1,10 @@ from typing import Any, Callable, Dict, List, Union +import numpy as np + from haystack import Pipeline from haystack.core.component import Component +from haystack.evaluation.eval_utils import get_answers_from_output, preprocess_text from haystack.evaluation.metrics import Metric, MetricsResult @@ -29,9 +32,15 @@ class EvaluationResult: self.outputs = outputs self.expected_outputs = expected_outputs + # Determine the type of the runnable + if str(type(runnable).__name__) == "Pipeline": + self.runnable_type = "pipeline" + else: + self.runnable_type = "component" + # Mapping of metrics to their corresponding functions. # This should be kept in sync with the Metric enum - self._supported_metrics = { + self._supported_metrics: Dict[Metric, Callable[..., MetricsResult]] = { Metric.RECALL: self._calculate_recall, Metric.MRR: self._calculate_mrr, Metric.MAP: self._calculate_map, @@ -65,8 +74,45 @@ class EvaluationResult: def _calculate_f1(self): return MetricsResult({"f1": None}) - def _calculate_em(self): - return MetricsResult({"exact_match": 1.0}) + def _calculate_em( + self, output_key: str, regexes_to_ignore=None, ignore_case=False, ignore_punctuation=False, ignore_numbers=False + ) -> MetricsResult: + """ + Calculates the Exact Match (EM) score between two lists of predictions and labels. + Exact Match (EM) score measures the percentage of samples where the predicted text exactly matches the + corresponding ground truth label. + + :param output_key: The key of the output to use for comparison. + :param regexes_to_ignore (list, optional): A list of regular expressions. If provided, it removes substrings + matching these regular expressions from both predictions and labels before comparison. Defaults to None. + :param ignore_case (bool, optional): If True, performs case-insensitive comparison. Defaults to False. + :param ignore_punctuation (bool, optional): If True, removes punctuation from both predictions and labels before + comparison. Defaults to False. 
+ :param ignore_numbers (bool, optional): If True, removes numerical digits from both predictions and labels + before comparison. Defaults to False. + :return: A MetricsResult object containing the calculated Exact Match (EM) score. + """ + + predictions = get_answers_from_output( + outputs=self.outputs, output_key=output_key, runnable_type=self.runnable_type + ) + labels = get_answers_from_output( + outputs=self.expected_outputs, output_key=output_key, runnable_type=self.runnable_type + ) + + if len(predictions) != len(labels): + raise ValueError("The number of predictions and labels must be the same.") + if len(predictions) == len(labels) == 0: + # Return Exact Match as 0 for no inputs + return MetricsResult({"exact_match": 0.0}) + + predictions = preprocess_text(predictions, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers) + labels = preprocess_text(labels, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers) + + score_list = np.array(predictions) == np.array(labels) + exact_match_score = np.mean(score_list) + + return MetricsResult({"exact_match": exact_match_score}) def _calculate_sas(self): return MetricsResult({"exact_match": None}) diff --git a/haystack/evaluation/eval_utils.py b/haystack/evaluation/eval_utils.py new file mode 100644 index 000000000..cdd6dd846 --- /dev/null +++ b/haystack/evaluation/eval_utils.py @@ -0,0 +1,65 @@ +import re +import string +from typing import Any, Dict, List + + +def preprocess_text( + texts: List[str], regexes_to_ignore=None, ignore_case=False, ignore_punctuation=False, ignore_numbers=False +) -> List[str]: + """ + Preprocess the outputs of the runnable to remove unwanted characters. + + :param regexes_to_ignore (list, optional): A list of regular expressions. If provided, it removes substrings + matching these regular expressions from the text. Defaults to None. + :param ignore_case (bool, optional): If True, converts all characters to lowercase. Defaults to False. + :param ignore_punctuation (bool, optional): If True, removes punctuation from the text. Defaults to False. + :param ignore_numbers (bool, optional): If True, removes numerical digits from the text. Defaults to False. + :return: A list of preprocessed strings. + """ + if regexes_to_ignore: + combined_regex = "|".join(regexes_to_ignore) + texts = [re.sub(combined_regex, "", text, flags=re.IGNORECASE) for text in texts] + + if ignore_case: + texts = [text.lower() for text in texts] + + if ignore_punctuation: + translator = str.maketrans("", "", string.punctuation) + texts = [text.translate(translator) for text in texts] + + if ignore_numbers: + translator = str.maketrans("", "", string.digits) + texts = [text.translate(translator) for text in texts] + + return texts + + +def get_answers_from_output(outputs: List[Dict[str, Any]], output_key: str, runnable_type: str) -> List[str]: + """ + Extracts the answers from the output of a pipeline or component. + + :param outputs: The outputs of the runnable. + :return: List of answers from the runnable output. 
+ """ + answers = [] + if runnable_type == "pipeline": + # Iterate over output from each Pipeline run + for output in outputs: + # Iterate over output of component in each Pipeline run + for component_output in output.values(): + # Only extract answers based on key + for key in component_output.keys(): + if output_key in key: + for generated_answer in component_output[output_key]: + if generated_answer.data: + answers.append(generated_answer.data) + else: + # Iterate over output from each Component run + for output in outputs: + # Only extract answers based on key + for key in output.keys(): + if output_key in key: + for generated_answer in output[output_key]: + if generated_answer.data: + answers.append(generated_answer.data) + return answers diff --git a/releasenotes/notes/add-exact-match-a7df21717238b771.yaml b/releasenotes/notes/add-exact-match-a7df21717238b771.yaml new file mode 100644 index 000000000..f553cac7a --- /dev/null +++ b/releasenotes/notes/add-exact-match-a7df21717238b771.yaml @@ -0,0 +1,8 @@ +--- +features: + - | + Adds support for the Exact Match metric to `EvaluationResult.calculate_metrics(...)`: + ```python + from haystack.evaluation.metrics import Metric + exact_match_metric = eval_result.calculate_metrics(Metric.EM, output_key="answers") + ``` diff --git a/test/evaluation/__init__.py b/test/evaluation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/test/evaluation/test_eval_exact_match.py b/test/evaluation/test_eval_exact_match.py new file mode 100644 index 000000000..ad0b8930d --- /dev/null +++ b/test/evaluation/test_eval_exact_match.py @@ -0,0 +1,167 @@ +from haystack import Pipeline +from haystack.dataclasses import GeneratedAnswer +from haystack.evaluation.eval import EvaluationResult + + +class TestExactMatch: + def create_evaluation_result(self, predictions, labels): + """ + Creates an evaluation result of a RAG pipeline using the list of predictions and labels for testing the exact match. 
+ """ + runnable = Pipeline() + inputs = [] + outputs = [ + {"answer_builder": {"answers": [GeneratedAnswer(data=pred, query="", documents=[], meta={})]}} + for pred in predictions + ] + expected_outputs = [ + {"answer_builder": {"answers": [GeneratedAnswer(data=label, query="", documents=[], meta={})]}} + for label in labels + ] + evaluation_result = EvaluationResult(runnable, inputs, outputs, expected_outputs) + return evaluation_result + + def test_exact_match_empty_inputs(self): + """ + Test exact match with empty inputs + """ + runnable = Pipeline() + inputs = [] + outputs = [ + {"answer_builder": {"answers": []}}, + {"answer_builder": {"answers": []}}, + {"answer_builder": {"answers": []}}, + ] + expected_outputs = [ + {"answer_builder": {"answers": []}}, + {"answer_builder": {"answers": []}}, + {"answer_builder": {"answers": []}}, + ] + evaluation_result = EvaluationResult(runnable, inputs, outputs, expected_outputs) + # Expecting 0% exact match for empty inputs + em_result = evaluation_result._calculate_em(output_key="answers") + + assert em_result["exact_match"] == 0.0 + + def test_exact_match_same_inputs(self): + """ + Test exact match with default parameters + """ + predictions = ["OpenSource", "HaystackAI", "LLMs"] + labels = ["OpenSource", "HaystackAI", "LLMs"] + evaluation_result = self.create_evaluation_result(predictions, labels) + em_result = evaluation_result._calculate_em(output_key="answers") + + assert em_result["exact_match"] == 1.0 + + def test_exact_match_single_word(self): + """ + Test exact match with single-word inputs + """ + predictions = ["OpenSource"] + labels = ["OpenSource"] + + evaluation_result = self.create_evaluation_result(predictions, labels) + em_result = evaluation_result._calculate_em(output_key="answers") + + assert em_result["exact_match"] == 1.0 + + def test_exact_match_negative_case(self): + """ + Test exact match with deliberately mismatched predictions and labels + """ + predictions = ["OpenSource", "HaystackAI", "LLMs"] + labels = ["Source", "HaystackAI", "LLMs"] + + evaluation_result = self.create_evaluation_result(predictions, labels) + # Expecting EM to be 2/3 as 2 out of 3 items match + expected_em = 2 / 3 + em_result = evaluation_result._calculate_em(output_key="answers") + + assert em_result["exact_match"] == expected_em + + def test_exact_match_ignore_case(self): + """ + Test exact match with ignoring case sensitivity + """ + predictions = ["OpenSource", "HaystackAI", "LLMs"] + labels = ["opensource", "HAYSTACKAI", "llMs"] + + evaluation_result = self.create_evaluation_result(predictions, labels) + # Exact match after case ignoring + em_result = evaluation_result._calculate_em(output_key="answers", ignore_case=True) + + assert em_result["exact_match"] == 1.0 + + def test_exact_match_ignore_punctuation(self): + """ + Test exact match with ignoring punctuation + """ + predictions = ["OpenSource!", "Haystack.AI", "LLMs,"] + labels = ["OpenSource", "HaystackAI", "LLMs"] + + evaluation_result = self.create_evaluation_result(predictions, labels) + # Exact match after ignoring punctuation + em_result = evaluation_result._calculate_em(output_key="answers", ignore_punctuation=True) + + assert em_result["exact_match"] == 1.0 + + def test_exact_match_ignore_numbers(self): + """ + Test exact match with ignoring numbers + """ + predictions = ["OpenSource123", "HaystackAI", "LLMs456"] + labels = ["OpenSource", "HaystackAI", "LLMs"] + + evaluation_result = self.create_evaluation_result(predictions, labels) + # Exact match after ignoring numbers + 
em_result = evaluation_result._calculate_em(output_key="answers", ignore_numbers=True) + assert em_result["exact_match"] == 1.0 + + def test_exact_match_regex_ignore(self): + """ + Test exact match with ignoring specific regex patterns + """ + predictions = ["Open123Source", "HaystackAI", "LLMs456"] + labels = ["OpenSource", "HaystackAI", "LLMs"] + + evaluation_result = self.create_evaluation_result(predictions, labels) + # Ignore numeric patterns + regex_to_ignore = [r"\d+"] + em_result = evaluation_result._calculate_em(output_key="answers", regexes_to_ignore=regex_to_ignore) + + assert em_result["exact_match"] == 1.0 + + def test_exact_match_multiple_ignore_regex(self): + """ + Test exact match with multiple ignoring parameters + """ + predictions = ["Open123!Source", "Haystack.AI", "LLMs456,"] + labels = ["OpenSource", "HaystackAI", "LLMs"] + + evaluation_result = self.create_evaluation_result(predictions, labels) + # Ignore numeric patterns and punctuation using regex + regex_to_ignore = [r"\d+", r"\W+"] + em_result = evaluation_result._calculate_em(output_key="answers", regexes_to_ignore=regex_to_ignore) + + assert em_result["exact_match"] == 1.0 + + def test_exact_match_multiple_ignore_combination(self): + """ + Test exact match with multiple ignoring parameters combined + """ + predictions = ["Open%123!$Source", "Haystack.AI##", "^^LLMs456,"] + labels = ["OpenSource", "HaystackAI", "LLMs"] + + evaluation_result = self.create_evaluation_result(predictions, labels) + # Ignore only special characters using regex + regex_to_ignore = [r"[^\w\s\d]+"] + em_result = evaluation_result._calculate_em( + output_key="answers", + ignore_numbers=True, + ignore_punctuation=True, + ignore_case=True, + regexes_to_ignore=regex_to_ignore, + ) + + assert em_result["exact_match"] == 1.0 diff --git a/test/evaluation/test_eval_utils.py b/test/evaluation/test_eval_utils.py new file mode 100644 index 000000000..bd7a946a7 --- /dev/null +++ b/test/evaluation/test_eval_utils.py @@ -0,0 +1,245 @@ +from haystack.dataclasses import GeneratedAnswer +from haystack.evaluation.eval_utils import get_answers_from_output, preprocess_text + + +class TestEvalUtils: + def test_extract_answers_from_pipeline_output(self): + """ + Test that the function correctly extracts answers from the output of a pipeline. + """ + outputs = [ + { + "answer_builder": { + "answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})] + } + }, + { + "answer_builder": { + "answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})] + } + }, + { + "answer_builder": { + "answers": [GeneratedAnswer(data="Giorgio", query="Who lives in Rome?", documents=[], meta={})] + } + }, + ] + + runnable_type = "pipeline" + output_key = "answers" + expected_answers = ["Jean", "Mark", "Giorgio"] + + assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers + + def test_extract_answers_from_component_output(self): + """ + Test that the function correctly extracts answers from the output of a component. 
+ """ + outputs = [ + {"answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})]}, + {"answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})]}, + {"answers": [GeneratedAnswer(data="Giorgio", query="Who lives in Rome?", documents=[], meta={})]}, + ] + runnable_type = "component" + output_key = "answers" + expected_answers = ["Jean", "Mark", "Giorgio"] + + assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers + + def test_ignore_other_output_keys(self): + """ + Test that the function only extracts answers and ignores other output keys. + """ + outputs = [ + { + "llm": {"replies": ["llm_reply_1"]}, + "answer_builder": { + "answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})] + }, + }, + { + "llm": {"replies": ["llm_reply_2"]}, + "answer_builder": { + "answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})] + }, + }, + { + "llm": {"replies": ["llm_reply_3"]}, + "answer_builder": { + "answers": [GeneratedAnswer(data="Giorgio", query="Who lives in Rome?", documents=[], meta={})] + }, + }, + ] + + runnable_type = "pipeline" + output_key = "answers" + expected_answers = ["Jean", "Mark", "Giorgio"] + + assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers + + def test_handle_empty_outputs(self): + """ + Test that the function correctly handles empty outputs. + """ + outputs = [] + runnable_type = "pipeline" + output_key = "answers" + expected_answers = [] + + assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers + + def test_handle_missing_keys(self): + """ + Test that the function correctly handles outputs with missing keys. + """ + outputs = [ + { + "llm": {"replies": ["llm_reply_1"]}, + "answer_builder": { + "answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})] + }, + }, + { + "answer_builder": { + "answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})] + } + }, + ] + + runnable_type = "pipeline" + output_key = "answers" + expected_answers = ["Jean", "Mark"] + + assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers + + def test_handle_missing_values(self): + """ + Test that the function correctly handles outputs with missing values. + """ + outputs = [ + {"answer_builder": {"answers": []}}, + { + "answer_builder": { + "answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})] + } + }, + ] + runnable_type = "pipeline" + output_key = "answers" + expected_answers = ["Mark"] + + assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers + + def test_preprocess_text_default_parameters(self): + """ + Test preprocess_text with default parameters. + There should be no changes to the input text. + """ + texts = ["Test, Output-1!", "Test, Output-2!"] + expected_output = ["Test, Output-1!", "Test, Output-2!"] + actual_output = preprocess_text(texts) + + assert actual_output == expected_output + + def test_preprocess_text_ignore_case(self): + """ + Test preprocess_text with ignore_case=True. + + """ + texts = ["Test, Output-1!"] + expected_output = ["test, output-1!"] + + actual_output = preprocess_text(texts, ignore_case=True) + + assert actual_output == expected_output + + def test_preprocess_text_ignore_punctuation(self): + """ + Test preprocess_text with ignore_punctuation=True. 
+ """ + texts = ["Test, Output-1!"] + expected_output = ["Test Output1"] + + actual_output = preprocess_text(texts, ignore_punctuation=True) + + assert actual_output == expected_output + + # Preprocess text with ignore_numbers=True. + def test_preprocess_text_ignore_numbers(self): + """ + Test preprocess_text with ignore_numbers=True. It should be able to remove numbers from the input. + """ + texts = ["Test, Output-1!"] + expected_output = ["Test, Output-!"] + + actual_output = preprocess_text(texts, ignore_numbers=True) + + assert actual_output == expected_output + + def test_preprocess_text_regexes_to_ignore(self): + """ + Test preprocess_text with a list of regex patterns to ignore. + """ + texts = ["Test, Output-1!"] + expected_output = ["Test Output"] + + # Use regex patterns to remove digits and non-alphanumeric characters + actual_output = preprocess_text(texts, regexes_to_ignore=[r"\d", r"[^\w\s]"]) + + assert actual_output == expected_output + + def test_preprocess_text_empty_list(self): + """ + Test preprocess_text with empty list of texts. + """ + texts = [] + expected_output = [] + + actual_output = preprocess_text(texts) + + assert actual_output == expected_output + + def test_preprocess_text_all_ignore_parameters(self): + """ + Test preprocess_text with all ignore parameters set to True. + """ + texts = ["Test, Output-1!"] + expected_output = ["test output"] + + actual_output = preprocess_text(texts, ignore_case=True, ignore_punctuation=True, ignore_numbers=True) + + assert actual_output == expected_output + + def test_preprocess_text_regexes_to_ignore_empty_string(self): + """ + Test preprocess_text with regexes_to_ignore=[""]. + """ + texts = ["Test, Output-1!"] + expected_output = ["Test, Output-1!"] + + actual_output = preprocess_text(texts, regexes_to_ignore=[""]) + + assert actual_output == expected_output + + # Preprocess text with regexes_to_ignore=[".*"]. + def test_preprocess_text_regexes_to_ignore_dot_star(self): + """ + Test preprocess_text with regexes_to_ignore=[".*"]. + """ + texts = ["Test, Output-1!"] + expected_output = [""] + + actual_output = preprocess_text(texts, regexes_to_ignore=[".*"]) + + assert actual_output == expected_output + + def test_preprocess_text_regexes_to_ignore_same_substring(self): + """ + Test preprocess_text with regexes_to_ignore where all the regex patterns match the same substring. + """ + texts = ["Test, Output-1!"] + expected_output = ["Test, Output-!"] + + actual_output = preprocess_text(texts, regexes_to_ignore=[r"\d", r"\d"]) + + assert actual_output == expected_output