feat: Add Exact Match metric (#6696)

* Add exact match metric

* Add release notes

* Cleanup comments in test_eval_exact_match.py

* Create separate preprocessing function; Add output_key parameter

* Update release note

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Co-authored-by: Julian Risch <julian.risch@deepset.ai>
Ashwin Mathur 2024-01-22 14:27:04 +05:30 committed by GitHub
parent 8a08ab52e1
commit a238c6dd51
8 changed files with 658 additions and 37 deletions


@@ -35,11 +35,7 @@ def test_extractive_qa_pipeline(tmp_path):
query="Who lives in Paris?",
score=0.7713339924812317,
data="Jean and I",
document=Document(
id="6c90b78ad94e4e634e2a067b5fe2d26d4ce95405ec222cbaefaeb09ab4dce81e",
content="My name is Jean and I live in Paris.",
score=0.33144005810482535,
),
document=Document(content="My name is Jean and I live in Paris.", score=0.33144005810482535),
context=None,
document_offset=ExtractedAnswer.Span(start=11, end=21),
context_offset=None,
@@ -65,11 +61,7 @@ def test_extractive_qa_pipeline(tmp_path):
query="Who lives in Berlin?",
score=0.7047999501228333,
data="Mark and I",
document=Document(
id="10a183e965c2e107e20507c717f16559c58a8ba4bc7c577ea8dc32a8d6ca7a20",
content="My name is Mark and I live in Berlin.",
score=0.33144005810482535,
),
document=Document(content="My name is Mark and I live in Berlin.", score=0.33144005810482535),
context=None,
document_offset=ExtractedAnswer.Span(start=11, end=21),
context_offset=None,
@@ -95,11 +87,7 @@ def test_extractive_qa_pipeline(tmp_path):
query="Who lives in Rome?",
score=0.7661304473876953,
data="Giorgio and I",
document=Document(
id="fb0f1efe94b3c78aa1c4e5a17a5ef8270f70e89d36a3665c8362675e8a769a27",
content="My name is Giorgio and I live in Rome.",
score=0.33144005810482535,
),
document=Document(content="My name is Giorgio and I live in Rome.", score=0.33144005810482535),
context=None,
document_offset=ExtractedAnswer.Span(start=11, end=24),
context_offset=None,
@@ -127,10 +115,14 @@ def test_extractive_qa_pipeline(tmp_path):
assert len(eval_result.outputs) == len(expected_outputs) == len(inputs)
assert eval_result.runnable.to_dict() == qa_pipeline.to_dict()
metrics = eval_result.calculate_metrics(Metric.EM)
metrics_default = eval_result.calculate_metrics(Metric.EM, output_key="answers")
metrics_custom_parameters = eval_result.calculate_metrics(
Metric.EM, output_key="answers", ignore_case=True, ignore_punctuation=True, ignore_numbers=True
)
# Save metric results to json
metrics.save(tmp_path / "exact_match_score.json")
metrics_default.save(tmp_path / "exact_match_score.json")
assert metrics["exact_match"] == 1.0
assert metrics_default["exact_match"] == 1.0
assert metrics_custom_parameters["exact_match"] == 1.0
with open(tmp_path / "exact_match_score.json", "r") as f:
assert metrics == json.load(f)
assert metrics_default == json.load(f)


@@ -7,7 +7,7 @@ from haystack.components.embedders import SentenceTransformersDocumentEmbedder,
from haystack.components.generators import HuggingFaceLocalGenerator
from haystack.components.retrievers.in_memory import InMemoryBM25Retriever, InMemoryEmbeddingRetriever
from haystack.components.writers import DocumentWriter
from haystack.dataclasses import Document
from haystack.dataclasses import Document, GeneratedAnswer
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.evaluation.eval import eval
from haystack.evaluation.metrics import Metric
@@ -59,9 +59,54 @@ def test_bm25_rag_pipeline(tmp_path):
]
expected_outputs = [
{"llm": {"replies": ["Jean"]}},
{"llm": {"replies": ["Mark"]}},
{"llm": {"replies": ["Giorgio"]}},
{
"answer_builder": {
"answers": [
GeneratedAnswer(
data="Jean",
query="Who lives in Paris?",
documents=[
Document(content="My name is Jean and I live in Paris.", score=0.33144005810482535),
Document(content="My name is Giorgio and I live in Rome.", score=-0.17938556566116537),
Document(content="My name is Mark and I live in Berlin.", score=-0.17938556566116537),
],
meta={},
)
]
}
},
{
"answer_builder": {
"answers": [
GeneratedAnswer(
data="Mark",
query="Who lives in Berlin?",
documents=[
Document(content="My name is Mark and I live in Berlin.", score=0.33144005810482535),
Document(content="My name is Giorgio and I live in Rome.", score=-0.17938556566116537),
Document(content="My name is Jean and I live in Paris.", score=-0.17938556566116537),
],
meta={},
)
]
}
},
{
"answer_builder": {
"answers": [
GeneratedAnswer(
data="Giorgio",
query="Who lives in Rome?",
documents=[
Document(content="My name is Giorgio and I live in Rome.", score=0.33144005810482535),
Document(content="My name is Mark and I live in Berlin.", score=-0.17938556566116537),
Document(content="My name is Jean and I live in Paris.", score=-0.17938556566116537),
],
meta={},
)
]
}
},
]
eval_result = eval(rag_pipeline, inputs=inputs, expected_outputs=expected_outputs)
@@ -71,13 +116,17 @@ def test_bm25_rag_pipeline(tmp_path):
assert len(eval_result.outputs) == len(expected_outputs) == len(inputs)
assert eval_result.runnable.to_dict() == rag_pipeline.to_dict()
metrics = eval_result.calculate_metrics(Metric.EM)
metrics_default = eval_result.calculate_metrics(Metric.EM, output_key="answers")
metrics_custom_parameters = eval_result.calculate_metrics(
Metric.EM, output_key="answers", ignore_case=True, ignore_punctuation=True, ignore_numbers=True
)
# Save metric results to json
metrics.save(tmp_path / "exact_match_score.json")
metrics_default.save(tmp_path / "exact_match_score.json")
assert metrics["exact_match"] == 1.0
assert metrics_default["exact_match"] == 1.0
assert metrics_custom_parameters["exact_match"] == 1.0
with open(tmp_path / "exact_match_score.json", "r") as f:
assert metrics == json.load(f)
assert metrics_default == json.load(f)
def test_embedding_retrieval_rag_pipeline(tmp_path):
@@ -142,9 +191,54 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
]
expected_outputs = [
{"llm": {"replies": ["Jean"]}},
{"llm": {"replies": ["Mark"]}},
{"llm": {"replies": ["Giorgio"]}},
{
"answer_builder": {
"answers": [
GeneratedAnswer(
data="Jean",
query="Who lives in Paris?",
documents=[
Document(content="My name is Jean and I live in Paris.", score=0.33144005810482535),
Document(content="My name is Giorgio and I live in Rome.", score=-0.17938556566116537),
Document(content="My name is Mark and I live in Berlin.", score=-0.17938556566116537),
],
meta={},
)
]
}
},
{
"answer_builder": {
"answers": [
GeneratedAnswer(
data="Mark",
query="Who lives in Berlin?",
documents=[
Document(content="My name is Mark and I live in Berlin.", score=0.33144005810482535),
Document(content="My name is Giorgio and I live in Rome.", score=-0.17938556566116537),
Document(content="My name is Jean and I live in Paris.", score=-0.17938556566116537),
],
meta={},
)
]
}
},
{
"answer_builder": {
"answers": [
GeneratedAnswer(
data="Giorgio",
query="Who lives in Rome?",
documents=[
Document(content="My name is Giorgio and I live in Rome.", score=0.33144005810482535),
Document(content="My name is Mark and I live in Berlin.", score=-0.17938556566116537),
Document(content="My name is Jean and I live in Paris.", score=-0.17938556566116537),
],
meta={},
)
]
}
},
]
eval_result = eval(rag_pipeline, inputs=inputs, expected_outputs=expected_outputs)
@@ -154,10 +248,14 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
assert len(eval_result.outputs) == len(expected_outputs) == len(inputs)
assert eval_result.runnable.to_dict() == rag_pipeline.to_dict()
metrics = eval_result.calculate_metrics(Metric.EM)
metrics_default = eval_result.calculate_metrics(Metric.EM, output_key="answers")
metrics_custom_parameters = eval_result.calculate_metrics(
Metric.EM, output_key="answers", ignore_case=True, ignore_punctuation=True, ignore_numbers=True
)
# Save metric results to json
metrics.save(tmp_path / "exact_match_score.json")
metrics_default.save(tmp_path / "exact_match_score.json")
assert metrics["exact_match"] == 1.0
assert metrics_default["exact_match"] == 1.0
assert metrics_custom_parameters["exact_match"] == 1.0
with open(tmp_path / "exact_match_score.json", "r") as f:
assert metrics == json.load(f)
assert metrics_default == json.load(f)


@@ -1,7 +1,10 @@
from typing import Any, Callable, Dict, List, Union
import numpy as np
from haystack import Pipeline
from haystack.core.component import Component
from haystack.evaluation.eval_utils import get_answers_from_output, preprocess_text
from haystack.evaluation.metrics import Metric, MetricsResult
@@ -29,9 +32,15 @@ class EvaluationResult:
self.outputs = outputs
self.expected_outputs = expected_outputs
# Determine the type of the runnable
if str(type(runnable).__name__) == "Pipeline":
self.runnable_type = "pipeline"
else:
self.runnable_type = "component"
# Mapping of metrics to their corresponding functions.
# This should be kept in sync with the Metric enum
self._supported_metrics = {
self._supported_metrics: Dict[Metric, Callable[..., MetricsResult]] = {
Metric.RECALL: self._calculate_recall,
Metric.MRR: self._calculate_mrr,
Metric.MAP: self._calculate_map,
@@ -65,8 +74,45 @@ class EvaluationResult:
def _calculate_f1(self):
return MetricsResult({"f1": None})
def _calculate_em(self):
return MetricsResult({"exact_match": 1.0})
def _calculate_em(
self, output_key: str, regexes_to_ignore=None, ignore_case=False, ignore_punctuation=False, ignore_numbers=False
) -> MetricsResult:
"""
Calculates the Exact Match (EM) score between two lists of predictions and labels.
Exact Match (EM) score measures the percentage of samples where the predicted text exactly matches the
corresponding ground truth label.
:param output_key: The key of the output to use for comparison.
:param regexes_to_ignore (list, optional): A list of regular expressions. If provided, it removes substrings
matching these regular expressions from both predictions and labels before comparison. Defaults to None.
:param ignore_case (bool, optional): If True, performs case-insensitive comparison. Defaults to False.
:param ignore_punctuation (bool, optional): If True, removes punctuation from both predictions and labels before
comparison. Defaults to False.
:param ignore_numbers (bool, optional): If True, removes numerical digits from both predictions and labels
before comparison. Defaults to False.
:return: A MetricsResult object containing the calculated Exact Match (EM) score.
"""
predictions = get_answers_from_output(
outputs=self.outputs, output_key=output_key, runnable_type=self.runnable_type
)
labels = get_answers_from_output(
outputs=self.expected_outputs, output_key=output_key, runnable_type=self.runnable_type
)
if len(predictions) != len(labels):
raise ValueError("The number of predictions and labels must be the same.")
if len(predictions) == len(labels) == 0:
# Return Exact Match as 0 for no inputs
return MetricsResult({"exact_match": 0.0})
predictions = preprocess_text(predictions, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers)
labels = preprocess_text(labels, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers)
score_list = np.array(predictions) == np.array(labels)
exact_match_score = np.mean(score_list)
return MetricsResult({"exact_match": exact_match_score})
def _calculate_sas(self):
return MetricsResult({"exact_match": None})


@@ -0,0 +1,65 @@
import re
import string
from typing import Any, Dict, List
def preprocess_text(
texts: List[str], regexes_to_ignore=None, ignore_case=False, ignore_punctuation=False, ignore_numbers=False
) -> List[str]:
"""
Preprocess the outputs of the runnable to remove unwanted characters.
:param regexes_to_ignore (list, optional): A list of regular expressions. If provided, it removes substrings
matching these regular expressions from the text. Defaults to None.
:param ignore_case (bool, optional): If True, converts all characters to lowercase. Defaults to False.
:param ignore_punctuation (bool, optional): If True, removes punctuation from the text. Defaults to False.
:param ignore_numbers (bool, optional): If True, removes numerical digits from the text. Defaults to False.
:return: A list of preprocessed strings.
"""
if regexes_to_ignore:
combined_regex = "|".join(regexes_to_ignore)
texts = [re.sub(combined_regex, "", text, flags=re.IGNORECASE) for text in texts]
if ignore_case:
texts = [text.lower() for text in texts]
if ignore_punctuation:
translator = str.maketrans("", "", string.punctuation)
texts = [text.translate(translator) for text in texts]
if ignore_numbers:
translator = str.maketrans("", "", string.digits)
texts = [text.translate(translator) for text in texts]
return texts
def get_answers_from_output(outputs: List[Dict[str, Any]], output_key: str, runnable_type: str) -> List[str]:
"""
Extracts the answers from the output of a pipeline or component.
:param outputs: The outputs of the runnable.
:return: List of answers from the runnable output.
"""
answers = []
if runnable_type == "pipeline":
# Iterate over output from each Pipeline run
for output in outputs:
# Iterate over output of component in each Pipeline run
for component_output in output.values():
# Only extract answers based on key
for key in component_output.keys():
if output_key in key:
for generated_answer in component_output[output_key]:
if generated_answer.data:
answers.append(generated_answer.data)
else:
# Iterate over output from each Component run
for output in outputs:
# Only extract answers based on key
for key in output.keys():
if output_key in key:
for generated_answer in output[output_key]:
if generated_answer.data:
answers.append(generated_answer.data)
return answers
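
For orientation, a short usage sketch of the two helpers added in this file. It assumes a Haystack install that contains this commit; the inputs mirror the unit tests further down in this diff:

```python
from haystack.dataclasses import GeneratedAnswer
from haystack.evaluation.eval_utils import get_answers_from_output, preprocess_text

# Pipeline-style outputs keyed by component name, as produced by an answer builder.
outputs = [
    {"answer_builder": {"answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})]}},
    {"answer_builder": {"answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})]}},
]

answers = get_answers_from_output(outputs=outputs, output_key="answers", runnable_type="pipeline")
assert answers == ["Jean", "Mark"]

# All three ignore flags applied at once, as in the preprocess_text tests below.
assert preprocess_text(["Test, Output-1!"], ignore_case=True, ignore_punctuation=True, ignore_numbers=True) == ["test output"]
```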


@@ -0,0 +1,8 @@
---
features:
- |
Adds support for the Exact Match metric to `EvaluationResult.calculate_metrics(...)`:
```python
from haystack.evaluation.metrics import Metric
exact_match_metric = eval_result.calculate_metrics(Metric.EM, output_key="answers")
```
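
Beyond the one-liner in the release note, the preprocessing flags can be passed through `calculate_metrics` as well. A slightly fuller sketch, built the same way the unit tests in this commit construct an `EvaluationResult` directly (it assumes a Haystack build containing this change):

```python
from haystack import Pipeline
from haystack.dataclasses import GeneratedAnswer
from haystack.evaluation.eval import EvaluationResult
from haystack.evaluation.metrics import Metric

predictions = ["OpenSource!", "HaystackAI", "LLMs,"]
labels = ["OpenSource", "HaystackAI", "LLMs"]

outputs = [
    {"answer_builder": {"answers": [GeneratedAnswer(data=p, query="", documents=[], meta={})]}} for p in predictions
]
expected_outputs = [
    {"answer_builder": {"answers": [GeneratedAnswer(data=label, query="", documents=[], meta={})]}} for label in labels
]

eval_result = EvaluationResult(Pipeline(), [], outputs, expected_outputs)
metrics = eval_result.calculate_metrics(Metric.EM, output_key="answers", ignore_punctuation=True)
assert metrics["exact_match"] == 1.0
```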


@@ -0,0 +1,167 @@
from haystack import Pipeline
from haystack.dataclasses import GeneratedAnswer
from haystack.evaluation.eval import EvaluationResult
class TestExactMatch:
def create_evaluation_result(self, predictions, labels):
"""
Creates an evaluation result of a RAG pipeline using the list of predictions and labels for testing the exact match.
"""
runnable = Pipeline()
inputs = []
outputs = [
{"answer_builder": {"answers": [GeneratedAnswer(data=pred, query="", documents=[], meta={})]}}
for pred in predictions
]
expected_outputs = [
{"answer_builder": {"answers": [GeneratedAnswer(data=label, query="", documents=[], meta={})]}}
for label in labels
]
evaluation_result = EvaluationResult(runnable, inputs, outputs, expected_outputs)
return evaluation_result
def test_exact_match_empty_inputs(self):
"""
Test exact match with empty inputs
"""
runnable = Pipeline()
inputs = []
outputs = [
{"answer_builder": {"answers": []}},
{"answer_builder": {"answers": []}},
{"answer_builder": {"answers": []}},
]
expected_outputs = [
{"answer_builder": {"answers": []}},
{"answer_builder": {"answers": []}},
{"answer_builder": {"answers": []}},
]
evaluation_result = EvaluationResult(runnable, inputs, outputs, expected_outputs)
# Expecting 0% exact match for empty inputs
em_result = evaluation_result._calculate_em(output_key="answers")
assert em_result["exact_match"] == 0.0
def test_exact_match_same_inputs(self):
"""
Test exact match with default parameters
"""
predictions = ["OpenSource", "HaystackAI", "LLMs"]
labels = ["OpenSource", "HaystackAI", "LLMs"]
evaluation_result = self.create_evaluation_result(predictions, labels)
em_result = evaluation_result._calculate_em(output_key="answers")
assert em_result["exact_match"] == 1.0
def test_exact_match_single_word(self):
"""
Test exact match with single-word inputs
"""
predictions = ["OpenSource"]
labels = ["OpenSource"]
evaluation_result = self.create_evaluation_result(predictions, labels)
em_result = evaluation_result._calculate_em(output_key="answers")
assert em_result["exact_match"] == 1.0
def test_exact_match_negative_case(self):
"""
Test exact match with deliberately mismatched predictions and labels
"""
predictions = ["OpenSource", "HaystackAI", "LLMs"]
labels = ["Source", "HaystackAI", "LLMs"]
evaluation_result = self.create_evaluation_result(predictions, labels)
# Expecting EM to be 2/3 as 2 out of 3 items match
expected_em = 2 / 3
em_result = evaluation_result._calculate_em(output_key="answers")
assert em_result["exact_match"] == expected_em
def test_exact_match_ignore_case(self):
"""
Test exact match with ignoring case sensitivity
"""
predictions = ["OpenSource", "HaystackAI", "LLMs"]
labels = ["opensource", "HAYSTACKAI", "llMs"]
evaluation_result = self.create_evaluation_result(predictions, labels)
# Exact match after case ignoring
em_result = evaluation_result._calculate_em(output_key="answers", ignore_case=True)
assert em_result["exact_match"] == 1.0
def test_exact_match_ignore_punctuation(self):
"""
Test exact match with ignoring punctuation
"""
predictions = ["OpenSource!", "Haystack.AI", "LLMs,"]
labels = ["OpenSource", "HaystackAI", "LLMs"]
evaluation_result = self.create_evaluation_result(predictions, labels)
# Exact match after ignoring punctuation
em_result = evaluation_result._calculate_em(output_key="answers", ignore_punctuation=True)
assert em_result["exact_match"] == 1.0
def test_exact_match_ignore_numbers(self):
"""
Test exact match with ignoring numbers
"""
predictions = ["OpenSource123", "HaystackAI", "LLMs456"]
labels = ["OpenSource", "HaystackAI", "LLMs"]
evaluation_result = self.create_evaluation_result(predictions, labels)
# Exact match after ignoring numbers
em_result = evaluation_result._calculate_em(output_key="answers", ignore_numbers=True)
assert em_result["exact_match"] == 1.0
def test_exact_match_regex_ignore(self):
"""
Test exact match with ignoring specific regex patterns
"""
predictions = ["Open123Source", "HaystackAI", "LLMs456"]
labels = ["OpenSource", "HaystackAI", "LLMs"]
evaluation_result = self.create_evaluation_result(predictions, labels)
# Ignore numeric patterns
regex_to_ignore = [r"\d+"]
em_result = evaluation_result._calculate_em(output_key="answers", regexes_to_ignore=regex_to_ignore)
assert em_result["exact_match"] == 1.0
def test_exact_match_multiple_ignore_regex(self):
"""
Test exact match with multiple ignoring parameters
"""
predictions = ["Open123!Source", "Haystack.AI", "LLMs456,"]
labels = ["OpenSource", "HaystackAI", "LLMs"]
evaluation_result = self.create_evaluation_result(predictions, labels)
# Ignore numeric patterns and punctuation using regex
regex_to_ignore = [r"\d+", r"\W+"]
em_result = evaluation_result._calculate_em(output_key="answers", regexes_to_ignore=regex_to_ignore)
assert em_result["exact_match"] == 1.0
def test_exact_match_multiple_ignore_combination(self):
"""
Test exact match with multiple ignoring parameters combined
"""
predictions = ["Open%123!$Source", "Haystack.AI##", "^^LLMs456,"]
labels = ["OpenSource", "HaystackAI", "LLMs"]
evaluation_result = self.create_evaluation_result(predictions, labels)
# Ignore only special characters using regex
regex_to_ignore = [r"[^\w\s\d]+"]
em_result = evaluation_result._calculate_em(
output_key="answers",
ignore_numbers=True,
ignore_punctuation=True,
ignore_case=True,
regexes_to_ignore=regex_to_ignore,
)
assert em_result["exact_match"] == 1.0


@@ -0,0 +1,245 @@
from haystack.dataclasses import GeneratedAnswer
from haystack.evaluation.eval_utils import get_answers_from_output, preprocess_text
class TestEvalUtils:
def test_extract_answers_from_pipeline_output(self):
"""
Test that the function correctly extracts answers from the output of a pipeline.
"""
outputs = [
{
"answer_builder": {
"answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})]
}
},
{
"answer_builder": {
"answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})]
}
},
{
"answer_builder": {
"answers": [GeneratedAnswer(data="Giorgio", query="Who lives in Rome?", documents=[], meta={})]
}
},
]
runnable_type = "pipeline"
output_key = "answers"
expected_answers = ["Jean", "Mark", "Giorgio"]
assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers
def test_extract_answers_from_component_output(self):
"""
Test that the function correctly extracts answers from the output of a component.
"""
outputs = [
{"answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})]},
{"answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})]},
{"answers": [GeneratedAnswer(data="Giorgio", query="Who lives in Rome?", documents=[], meta={})]},
]
runnable_type = "component"
output_key = "answers"
expected_answers = ["Jean", "Mark", "Giorgio"]
assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers
def test_ignore_other_output_keys(self):
"""
Test that the function only extracts answers and ignores other output keys.
"""
outputs = [
{
"llm": {"replies": ["llm_reply_1"]},
"answer_builder": {
"answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})]
},
},
{
"llm": {"replies": ["llm_reply_2"]},
"answer_builder": {
"answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})]
},
},
{
"llm": {"replies": ["llm_reply_3"]},
"answer_builder": {
"answers": [GeneratedAnswer(data="Giorgio", query="Who lives in Rome?", documents=[], meta={})]
},
},
]
runnable_type = "pipeline"
output_key = "answers"
expected_answers = ["Jean", "Mark", "Giorgio"]
assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers
def test_handle_empty_outputs(self):
"""
Test that the function correctly handles empty outputs.
"""
outputs = []
runnable_type = "pipeline"
output_key = "answers"
expected_answers = []
assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers
def test_handle_missing_keys(self):
"""
Test that the function correctly handles outputs with missing keys.
"""
outputs = [
{
"llm": {"replies": ["llm_reply_1"]},
"answer_builder": {
"answers": [GeneratedAnswer(data="Jean", query="Who lives in Paris?", documents=[], meta={})]
},
},
{
"answer_builder": {
"answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})]
}
},
]
runnable_type = "pipeline"
output_key = "answers"
expected_answers = ["Jean", "Mark"]
assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers
def test_handle_missing_values(self):
"""
Test that the function correctly handles outputs with missing values.
"""
outputs = [
{"answer_builder": {"answers": []}},
{
"answer_builder": {
"answers": [GeneratedAnswer(data="Mark", query="Who lives in Berlin?", documents=[], meta={})]
}
},
]
runnable_type = "pipeline"
output_key = "answers"
expected_answers = ["Mark"]
assert get_answers_from_output(outputs, output_key, runnable_type) == expected_answers
def test_preprocess_text_default_parameters(self):
"""
Test preprocess_text with default parameters.
There should be no changes to the input text.
"""
texts = ["Test, Output-1!", "Test, Output-2!"]
expected_output = ["Test, Output-1!", "Test, Output-2!"]
actual_output = preprocess_text(texts)
assert actual_output == expected_output
def test_preprocess_text_ignore_case(self):
"""
Test preprocess_text with ignore_case=True.
"""
texts = ["Test, Output-1!"]
expected_output = ["test, output-1!"]
actual_output = preprocess_text(texts, ignore_case=True)
assert actual_output == expected_output
def test_preprocess_text_ignore_punctuation(self):
"""
Test preprocess_text with ignore_punctuation=True.
"""
texts = ["Test, Output-1!"]
expected_output = ["Test Output1"]
actual_output = preprocess_text(texts, ignore_punctuation=True)
assert actual_output == expected_output
# Preprocess text with ignore_numbers=True.
def test_preprocess_text_ignore_numbers(self):
"""
Test preprocess_text with ignore_numbers=True. It should be able to remove numbers from the input.
"""
texts = ["Test, Output-1!"]
expected_output = ["Test, Output-!"]
actual_output = preprocess_text(texts, ignore_numbers=True)
assert actual_output == expected_output
def test_preprocess_text_regexes_to_ignore(self):
"""
Test preprocess_text with a list of regex patterns to ignore.
"""
texts = ["Test, Output-1!"]
expected_output = ["Test Output"]
# Use regex patterns to remove digits and non-alphanumeric characters
actual_output = preprocess_text(texts, regexes_to_ignore=[r"\d", r"[^\w\s]"])
assert actual_output == expected_output
def test_preprocess_text_empty_list(self):
"""
Test preprocess_text with empty list of texts.
"""
texts = []
expected_output = []
actual_output = preprocess_text(texts)
assert actual_output == expected_output
def test_preprocess_text_all_ignore_parameters(self):
"""
Test preprocess_text with all ignore parameters set to True.
"""
texts = ["Test, Output-1!"]
expected_output = ["test output"]
actual_output = preprocess_text(texts, ignore_case=True, ignore_punctuation=True, ignore_numbers=True)
assert actual_output == expected_output
def test_preprocess_text_regexes_to_ignore_empty_string(self):
"""
Test preprocess_text with regexes_to_ignore=[""].
"""
texts = ["Test, Output-1!"]
expected_output = ["Test, Output-1!"]
actual_output = preprocess_text(texts, regexes_to_ignore=[""])
assert actual_output == expected_output
# Preprocess text with regexes_to_ignore=[".*"].
def test_preprocess_text_regexes_to_ignore_dot_star(self):
"""
Test preprocess_text with regexes_to_ignore=[".*"].
"""
texts = ["Test, Output-1!"]
expected_output = [""]
actual_output = preprocess_text(texts, regexes_to_ignore=[".*"])
assert actual_output == expected_output
def test_preprocess_text_regexes_to_ignore_same_substring(self):
"""
Test preprocess_text with regexes_to_ignore where all the regex patterns match the same substring.
"""
texts = ["Test, Output-1!"]
expected_output = ["Test, Output-!"]
actual_output = preprocess_text(texts, regexes_to_ignore=[r"\d", r"\d"])
assert actual_output == expected_output