From 393a7993c329edd3bbb99ceb8fec3a0bc7cfb370 Mon Sep 17 00:00:00 2001 From: Ashwin Mathur <97467100+awinml@users.noreply.github.com> Date: Fri, 2 Feb 2024 21:37:52 +0530 Subject: [PATCH] feat: Add Semantic Answer Similarity metric (#6877) * Add SAS metric * Add release notes * Round similarity scores for precision consistency * Add tolerance to tests * Update haystack/evaluation/eval.py Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> * Add types for preprocess_text; Add additional types for f1 and em methods --------- Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com> --- .../test_eval_extractive_qa_pipeline.py | 24 ++ e2e/pipelines/test_eval_rag_pipelines.py | 47 +++ haystack/evaluation/eval.py | 127 ++++++- haystack/evaluation/eval_utils.py | 8 +- .../notes/add-sas-b8dbf61c0d78ba19.yaml | 10 + test/evaluation/test_eval_sas.py | 347 ++++++++++++++++++ 6 files changed, 555 insertions(+), 8 deletions(-) create mode 100644 releasenotes/notes/add-sas-b8dbf61c0d78ba19.yaml create mode 100644 test/evaluation/test_eval_sas.py diff --git a/e2e/pipelines/test_eval_extractive_qa_pipeline.py b/e2e/pipelines/test_eval_extractive_qa_pipeline.py index f04ddd472..989b2713b 100644 --- a/e2e/pipelines/test_eval_extractive_qa_pipeline.py +++ b/e2e/pipelines/test_eval_extractive_qa_pipeline.py @@ -1,4 +1,5 @@ import json +import pytest from haystack import Pipeline from haystack.components.readers import ExtractiveReader @@ -140,3 +141,26 @@ def test_extractive_qa_pipeline(tmp_path): assert f1_custom_parameters["f1"] == 1.0 with open(tmp_path / "f1_score.json", "r") as f: assert f1_default == json.load(f) + + # Test SAS + sas_default = eval_result.calculate_metrics( + Metric.SAS, output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2" + ) + sas_custom_parameters = eval_result.calculate_metrics( + Metric.SAS, + output_key="answers", + ignore_case=True, + ignore_punctuation=True, + ignore_numbers=True, + model="cross-encoder/ms-marco-MiniLM-L-6-v2", + ) + # Save SAS metric results to json + sas_default.save(tmp_path / "sas_score.json") + + assert sas_default["sas"] == pytest.approx(1.0) + assert sas_default["scores"] == pytest.approx([1.0, 1.0, 1.0]) + assert sas_custom_parameters["sas"] == pytest.approx(0.9996823, abs=1e-5) + assert sas_custom_parameters["scores"] == pytest.approx([0.999672, 0.999608, 0.999767]) + + with open(tmp_path / "sas_score.json", "r") as f: + assert sas_default == json.load(f) diff --git a/e2e/pipelines/test_eval_rag_pipelines.py b/e2e/pipelines/test_eval_rag_pipelines.py index 23e674d22..7dc512977 100644 --- a/e2e/pipelines/test_eval_rag_pipelines.py +++ b/e2e/pipelines/test_eval_rag_pipelines.py @@ -1,4 +1,5 @@ import json +import pytest from haystack import Pipeline from haystack.components.builders.answer_builder import AnswerBuilder @@ -142,6 +143,29 @@ def test_bm25_rag_pipeline(tmp_path): with open(tmp_path / "f1_score.json", "r") as f: assert f1_default == json.load(f) + # Test SAS + sas_default = eval_result.calculate_metrics( + Metric.SAS, output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2" + ) + sas_custom_parameters = eval_result.calculate_metrics( + Metric.SAS, + output_key="answers", + ignore_case=True, + ignore_punctuation=True, + ignore_numbers=True, + model="cross-encoder/ms-marco-MiniLM-L-6-v2", + ) + # Save SAS metric results to json + sas_default.save(tmp_path / "sas_score.json") + + assert sas_default["sas"] == pytest.approx(1.0) + 
assert sas_default["scores"] == pytest.approx([1.0, 1.0, 1.0]) + assert sas_custom_parameters["sas"] == pytest.approx(0.9769593, abs=1e-5) + assert sas_custom_parameters["scores"] == pytest.approx([0.975823, 0.957218, 0.997837], abs=1e-5) + + with open(tmp_path / "sas_score.json", "r") as f: + assert sas_default == json.load(f) + def test_embedding_retrieval_rag_pipeline(tmp_path): # Create the RAG pipeline @@ -287,3 +311,26 @@ def test_embedding_retrieval_rag_pipeline(tmp_path): assert f1_custom_parameters["f1"] == 1.0 with open(tmp_path / "f1_score.json", "r") as f: assert f1_default == json.load(f) + + # Test SAS + sas_default = eval_result.calculate_metrics( + Metric.SAS, output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2" + ) + sas_custom_parameters = eval_result.calculate_metrics( + Metric.SAS, + output_key="answers", + ignore_case=True, + ignore_punctuation=True, + ignore_numbers=True, + model="cross-encoder/ms-marco-MiniLM-L-6-v2", + ) + # Save SAS metric results to json + sas_default.save(tmp_path / "sas_score.json") + + assert sas_default["sas"] == pytest.approx(1.0) + assert sas_default["scores"] == pytest.approx([1.0, 1.0, 1.0]) + assert sas_custom_parameters["sas"] == pytest.approx(0.9769593, abs=1e-5) + assert sas_custom_parameters["scores"] == pytest.approx([0.975823, 0.957218, 0.997837], abs=1e-5) + + with open(tmp_path / "sas_score.json", "r") as f: + assert sas_default == json.load(f) diff --git a/haystack/evaluation/eval.py b/haystack/evaluation/eval.py index 7951814c1..3ce375ee2 100644 --- a/haystack/evaluation/eval.py +++ b/haystack/evaluation/eval.py @@ -1,5 +1,5 @@ import collections -from typing import Any, Callable, Dict, List, Union +from typing import Any, Callable, Dict, List, Union, Optional import numpy as np @@ -8,6 +8,13 @@ from haystack.core.component import Component from haystack.evaluation.eval_utils import get_answers_from_output, preprocess_text from haystack.evaluation.metrics import Metric, MetricsResult +from haystack.lazy_imports import LazyImport +from haystack.utils import ComponentDevice, expit + +with LazyImport(message="Run 'pip install scikit-learn \"sentence-transformers>=2.2.0\"'") as metrics_import: + from sentence_transformers import SentenceTransformer, CrossEncoder, util + from transformers import AutoConfig + class EvaluationResult: """ @@ -89,7 +96,12 @@ class EvaluationResult: return f1 def _calculate_f1( - self, output_key: str, regexes_to_ignore=None, ignore_case=False, ignore_punctuation=False, ignore_numbers=False + self, + output_key: str, + regexes_to_ignore: Optional[List[str]] = None, + ignore_case: bool = False, + ignore_punctuation: bool = False, + ignore_numbers: bool = False, ) -> MetricsResult: """ Calculates the F1 score between two lists of predictions and labels. @@ -103,7 +115,7 @@ class EvaluationResult: comparison. Defaults to False. :param ignore_numbers (bool, optional): If True, removes numerical digits from both predictions and labels before comparison. Defaults to False. - :return: A MetricsResult object containing the calculated Exact Match (EM) score. + :return: A MetricsResult object containing the calculated F1 score. 
""" predictions = get_answers_from_output( @@ -136,7 +148,12 @@ class EvaluationResult: return MetricsResult({"f1": f1}) def _calculate_em( - self, output_key: str, regexes_to_ignore=None, ignore_case=False, ignore_punctuation=False, ignore_numbers=False + self, + output_key: str, + regexes_to_ignore: Optional[List[str]] = None, + ignore_case: bool = False, + ignore_punctuation: bool = False, + ignore_numbers: bool = False, ) -> MetricsResult: """ Calculates the Exact Match (EM) score between two lists of predictions and labels. @@ -175,8 +192,106 @@ class EvaluationResult: return MetricsResult({"exact_match": exact_match_score}) - def _calculate_sas(self): - return MetricsResult({"exact_match": None}) + def _calculate_sas( + self, + output_key: str, + regexes_to_ignore: Optional[List[str]] = None, + ignore_case: bool = False, + ignore_punctuation: bool = False, + ignore_numbers: bool = False, + model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2", + batch_size: int = 32, + device: Optional[ComponentDevice] = None, + token: Optional[Union[str, bool]] = None, + ) -> MetricsResult: + """ + Calculates the Semantic Answer Similarity (SAS) score between two lists of predictions and labels. + Semantic Answer Similarity (SAS) score measures the Transformer-based similarity between the predicted text and + the corresponding ground truth label. + + :param output_key: The key of the output to use for comparison. + :param regexes_to_ignore (list, optional): A list of regular expressions. If provided, it removes substrings + matching these regular expressions from both predictions and labels before comparison. Defaults to None. + :param ignore_case (bool, optional): If True, performs case-insensitive comparison. Defaults to False. + :param ignore_punctuation (bool, optional): If True, removes punctuation from both predictions and labels before + comparison. Defaults to False. + :param ignore_numbers (bool, optional): If True, removes numerical digits from both predictions and labels + before comparison. Defaults to False. + :param model: SentenceTransformers semantic textual similarity model, should be path or string pointing to + a downloadable model. + :param batch_size: Number of prediction-label pairs to encode at once. + :param device: The device on which the model is loaded. If `None`, the default device is automatically + selected. + :param token: The token to use as HTTP bearer authorization for private models from Huggingface. + If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface). + Additional information can be found here: + https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained + :return: A MetricsResult object containing the calculated Semantic Answer Similarity (SAS) score and the + list of similarity scores obtained for each prediction-label pair. 
+        """
+        metrics_import.check()
+
+        predictions = get_answers_from_output(
+            outputs=self.outputs, output_key=output_key, runnable_type=self.runnable_type
+        )
+        labels = get_answers_from_output(
+            outputs=self.expected_outputs, output_key=output_key, runnable_type=self.runnable_type
+        )
+
+        if len(predictions) != len(labels):
+            raise ValueError("The number of predictions and labels must be the same.")
+        if len(predictions) == len(labels) == 0:
+            # Return SAS as 0 for no inputs
+            return MetricsResult({"sas": 0.0, "scores": [0.0]})
+
+        predictions = preprocess_text(predictions, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers)
+        labels = preprocess_text(labels, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers)
+
+        config = AutoConfig.from_pretrained(model, use_auth_token=token)
+        cross_encoder_used = False
+        if config.architectures:
+            cross_encoder_used = any(arch.endswith("ForSequenceClassification") for arch in config.architectures)
+
+        device = ComponentDevice.resolve_device(device)
+
+        # Based on the model architecture, we load either a Bi-Encoder or a Cross-Encoder.
+        # The similarity computation differs between the two approaches.
+
+        if cross_encoder_used:
+            # For Cross-Encoders we create a list of pairs of predictions and labels
+            similarity_model = CrossEncoder(
+                model,
+                device=device.to_torch_str(),
+                tokenizer_args={"use_auth_token": token},
+                automodel_args={"use_auth_token": token},
+            )
+            sentence_pairs = [[pred, label] for pred, label in zip(predictions, labels)]
+            similarity_scores = similarity_model.predict(sentence_pairs, batch_size=batch_size, convert_to_numpy=True)

+            # Not all Cross-Encoders return normalized scores; some emit raw logits.
+            # We normalize the scores with the sigmoid if any of them is larger than 1.
+            if (similarity_scores > 1).any():
+                similarity_scores = expit(similarity_scores)
+
+            # Convert the NumPy array of scores to a list of floats
+            similarity_scores = similarity_scores.tolist()
+
+        else:
+            # For Bi-Encoders we create embeddings separately for predictions and labels
+            similarity_model = SentenceTransformer(model, device=device.to_torch_str(), use_auth_token=token)
+            pred_embeddings = similarity_model.encode(predictions, batch_size=batch_size, convert_to_tensor=True)
+            label_embeddings = similarity_model.encode(labels, batch_size=batch_size, convert_to_tensor=True)
+
+            # Compute cosine similarities
+            scores = util.cos_sim(pred_embeddings, label_embeddings)
+
+            # cos_sim computes the cosine similarity between all pairs of vectors in pred_embeddings and
+            # label_embeddings. It returns a matrix with shape (len(predictions), len(labels)), so we take
+            # the diagonal to score each prediction against its own label.
+            similarity_scores = [scores[i][i].item() for i in range(len(predictions))]
+
+        sas_score = np.mean(similarity_scores)
+
+        return MetricsResult({"sas": sas_score, "scores": similarity_scores})
 
 
 def eval(
diff --git a/haystack/evaluation/eval_utils.py b/haystack/evaluation/eval_utils.py
index cdd6dd846..3b7d1ca1d 100644
--- a/haystack/evaluation/eval_utils.py
+++ b/haystack/evaluation/eval_utils.py
@@ -1,10 +1,14 @@
 import re
 import string
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 
 def preprocess_text(
-    texts: List[str], regexes_to_ignore=None, ignore_case=False, ignore_punctuation=False, ignore_numbers=False
+    texts: List[str],
+    regexes_to_ignore: Optional[List[str]] = None,
+    ignore_case: bool = False,
+    ignore_punctuation: bool = False,
+    ignore_numbers: bool = False,
 ) -> List[str]:
     """
     Preprocess the outputs of the runnable to remove unwanted characters.
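Reviewer note: the Bi-Encoder branch above scores each prediction against its own label by taking the diagonal of the cosine-similarity matrix. A minimal standalone sketch of that computation (the sentences below are illustrative, not taken from the test data; it assumes `sentence-transformers` is installed):

```python
from sentence_transformers import SentenceTransformer, util

# Illustrative prediction/label pairs; any two equal-length lists of strings work.
predictions = ["The Eiffel Tower was completed in 1889."]
labels = ["Construction of the Eiffel Tower finished in 1889."]

model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
pred_embeddings = model.encode(predictions, convert_to_tensor=True)
label_embeddings = model.encode(labels, convert_to_tensor=True)

# util.cos_sim returns a (len(predictions), len(labels)) matrix of cosine
# similarities; the diagonal entry [i][i] pairs prediction i with label i.
scores = util.cos_sim(pred_embeddings, label_embeddings)
similarity_scores = [scores[i][i].item() for i in range(len(predictions))]
print(similarity_scores)
```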
diff --git a/releasenotes/notes/add-sas-b8dbf61c0d78ba19.yaml b/releasenotes/notes/add-sas-b8dbf61c0d78ba19.yaml
new file mode 100644
index 000000000..5078a4c93
--- /dev/null
+++ b/releasenotes/notes/add-sas-b8dbf61c0d78ba19.yaml
@@ -0,0 +1,10 @@
+---
+features:
+  - |
+    Adds support for the Semantic Answer Similarity (SAS) metric to `EvaluationResult.calculate_metrics(...)`:
+    ```python
+    from haystack.evaluation.metrics import Metric
+    sas_metric = eval_result.calculate_metrics(
+        Metric.SAS, output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+    )
+    ```
diff --git a/test/evaluation/test_eval_sas.py b/test/evaluation/test_eval_sas.py
new file mode 100644
index 000000000..5b1fb70a8
--- /dev/null
+++ b/test/evaluation/test_eval_sas.py
@@ -0,0 +1,347 @@
+import pytest
+
+from haystack import Pipeline
+from haystack.dataclasses import GeneratedAnswer
+from haystack.evaluation.eval import EvaluationResult
+
+
+class TestSAS:
+    def create_evaluation_result(self, predictions, labels):
+        """
+        Creates an evaluation result for a RAG pipeline from the given lists of predictions and labels, for testing
+        the Semantic Answer Similarity (SAS) metric.
+        """
+        runnable = Pipeline()
+        inputs = []
+        outputs = [
+            {"answer_builder": {"answers": [GeneratedAnswer(data=pred, query="", documents=[], meta={})]}}
+            for pred in predictions
+        ]
+        expected_outputs = [
+            {"answer_builder": {"answers": [GeneratedAnswer(data=label, query="", documents=[], meta={})]}}
+            for label in labels
+        ]
+        evaluation_result = EvaluationResult(runnable, inputs, outputs, expected_outputs)
+        return evaluation_result
+
+    def test_sas_empty_inputs(self):
+        """
+        Test calculation of Semantic Answer Similarity (SAS) Score with empty inputs.
+        """
+        runnable = Pipeline()
+        inputs = []
+        outputs = [
+            {"answer_builder": {"answers": []}},
+            {"answer_builder": {"answers": []}},
+            {"answer_builder": {"answers": []}},
+        ]
+        expected_outputs = [
+            {"answer_builder": {"answers": []}},
+            {"answer_builder": {"answers": []}},
+            {"answer_builder": {"answers": []}},
+        ]
+        evaluation_result = EvaluationResult(runnable, inputs, outputs, expected_outputs)
+        # Expecting a SAS score of 0.0 for empty inputs
+        sas_result = evaluation_result._calculate_sas(
+            output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+        )
+
+        assert sas_result["sas"] == 0.0
+        assert sas_result["scores"] == [0.0]
+
+    def test_calculate_sas_with_different_lengths(self):
+        """
+        Test that the Semantic Answer Similarity (SAS) Score calculation raises a ValueError when predictions and
+        labels differ in length.
+        """
+        predictions = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+        ]
+        labels = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+        ]
+        evaluation_result = self.create_evaluation_result(predictions, labels)
+
+        with pytest.raises(ValueError, match="The number of predictions and labels must be the same."):
+            evaluation_result._calculate_sas(
+                output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+            )
+
+    @pytest.mark.integration
+    def test_sas_same_inputs(self):
+        """
+        Test calculation of Semantic Answer Similarity (SAS) Score with identical predictions and labels.
+        """
+        predictions = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+        ]
+        labels = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+        ]
+        evaluation_result = self.create_evaluation_result(predictions, labels)
+        sas_result = evaluation_result._calculate_sas(
+            output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+        )
+
+        assert sas_result["sas"] == pytest.approx(1.0)
+        assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])
+
+    @pytest.mark.integration
+    def test_sas_single_word(self):
+        """
+        Test calculation of Semantic Answer Similarity (SAS) Score on a single prediction-label pair.
+        """
+        predictions = ["A construction budget of US $2.3 billion"]
+        labels = ["US $2.3 billion"]
+
+        evaluation_result = self.create_evaluation_result(predictions, labels)
+        sas_result = evaluation_result._calculate_sas(
+            output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+        )
+
+        assert sas_result["sas"] == pytest.approx(0.689089, abs=1e-5)
+        assert sas_result["scores"] == pytest.approx([0.689089], abs=1e-5)
+
+    @pytest.mark.integration
+    def test_sas_negative_case(self):
+        """
+        Test calculation of Semantic Answer Similarity (SAS) Score with non-identical predictions and labels.
+        """
+        predictions = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+        ]
+        labels = [
+            "US $2.3 billion",
+            "Paris's cultural magnificence is symbolized by the Eiffel Tower",
+            "Japan was transformed into a modernized world power after the Meiji Restoration.",
+        ]
+
+        evaluation_result = self.create_evaluation_result(predictions, labels)
+        sas_result = evaluation_result._calculate_sas(
+            output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+        )
+
+        assert sas_result["sas"] == pytest.approx(0.8227189)
+        assert sas_result["scores"] == pytest.approx([0.689089, 0.870389, 0.908679], abs=1e-5)
+
+    @pytest.mark.integration
+    def test_sas_ignore_case(self):
+        """
+        Test calculation of Semantic Answer Similarity (SAS) Score when case differences are ignored.
+        """
+        predictions = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+        ]
+        labels = [
+            "A construction budget of US $2.3 BILLION",
+            "The EIFFEL TOWER, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The MEIJI RESTORATION in 1868 transformed Japan into a modernized world power.",
+        ]
+
+        evaluation_result = self.create_evaluation_result(predictions, labels)
+        # SAS after ignoring case
+        sas_result = evaluation_result._calculate_sas(
+            output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2", ignore_case=True
+        )
+
+        assert sas_result["sas"] == pytest.approx(1.0)
+        assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])
+
+    @pytest.mark.integration
+    def test_sas_ignore_punctuation(self):
+        """
+        Test calculation of Semantic Answer Similarity (SAS) Score when punctuation is ignored.
+        """
+        predictions = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
+        ]
+        labels = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower completed in 1889 symbolizes Paris's cultural magnificence",
+            "The Meiji Restoration in 1868 transformed Japan into a modernized world power",
+        ]
+
+        evaluation_result = self.create_evaluation_result(predictions, labels)
+        # SAS after ignoring punctuation
+        sas_result = evaluation_result._calculate_sas(
+            output_key="answers",
+            model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
+            ignore_punctuation=True,
+        )
+
+        assert sas_result["sas"] == pytest.approx(1.0)
+        assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])
+
+    @pytest.mark.integration
+    def test_sas_ignore_numbers(self):
+        """
+        Test calculation of Semantic Answer Similarity (SAS) Score when numbers are ignored.
+        """
+        predictions = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
+        ]
+        labels = [
+            "A construction budget of US $10.3 billion",
+            "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
+        ]
+
+        evaluation_result = self.create_evaluation_result(predictions, labels)
+        # SAS after ignoring numbers
+        sas_result = evaluation_result._calculate_sas(
+            output_key="answers",
+            model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
+            ignore_numbers=True,
+        )
+
+        assert sas_result["sas"] == pytest.approx(1.0)
+        assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])
+
+    @pytest.mark.integration
+    def test_sas_regex_ignore(self):
+        """
+        Test calculation of Semantic Answer Similarity (SAS) Score when substrings matching specific regex patterns
+        are ignored.
+        """
+        predictions = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
+        ]
+        labels = [
+            "A construction budget of US $10.3 billion",
+            "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
+        ]
+
+        evaluation_result = self.create_evaluation_result(predictions, labels)
+        # Ignore numeric patterns
+        regex_to_ignore = [r"\d+"]
+        sas_result = evaluation_result._calculate_sas(
+            output_key="answers",
+            model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
+            regexes_to_ignore=regex_to_ignore,
+        )
+
+        assert sas_result["sas"] == pytest.approx(1.0)
+        assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])
+
+    @pytest.mark.integration
+    def test_sas_multiple_ignore_regex(self):
+        """
+        Test calculation of Semantic Answer Similarity (SAS) Score when multiple regex patterns are ignored.
+        """
+        predictions = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
+        ]
+        labels = [
+            "A construction budget of US #10.3 billion",
+            "The Eiffel Tower!!, completed in 2005, symbolizes Paris's cultural magnificence.",
+            "The **Meiji Restoration**, in 1989, transformed Japan into a modernized world power.",
+        ]
+
+        evaluation_result = self.create_evaluation_result(predictions, labels)
+        # Ignore numeric patterns and punctuation, excluding whitespace
+        regex_to_ignore = [r"\d+", r"[^\w\s]"]
+        sas_result = evaluation_result._calculate_sas(
+            output_key="answers",
+            model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
+            regexes_to_ignore=regex_to_ignore,
+        )
+
+        assert sas_result["sas"] == pytest.approx(1.0)
+        assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])
+
+    @pytest.mark.integration
+    def test_sas_multiple_ignore_combination(self):
+        """
+        Test calculation of Semantic Answer Similarity (SAS) Score when multiple ignore parameters are combined.
+        """
+        predictions = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
+        ]
+        labels = [
+            "A construction budget of US #10.3 BILLION",
+            "The EIFFEL TOWER!!, completed in 2005, symbolizes Paris's cultural magnificence.",
+            "The **MEIJI RESTORATION**, in 1989, transformed Japan into a modernized world power.",
+        ]
+
+        evaluation_result = self.create_evaluation_result(predictions, labels)
+        # Ignore only special characters using regex
+        regex_to_ignore = [r"[^\w\s\d]+"]
+        sas_result = evaluation_result._calculate_sas(
+            output_key="answers",
+            model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
+            ignore_numbers=True,
+            ignore_punctuation=True,
+            ignore_case=True,
+            regexes_to_ignore=regex_to_ignore,
+        )
+
+        assert sas_result["sas"] == pytest.approx(1.0)
+        assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])
+
+    @pytest.mark.integration
+    def test_sas_bi_encoder(self):
+        """
+        Test calculation of Semantic Answer Similarity (SAS) Score using a Bi-Encoder model.
+ """ + predictions = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + labels = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + evaluation_result = self.create_evaluation_result(predictions, labels) + sas_result = evaluation_result._calculate_sas( + output_key="answers", model="sentence-transformers/all-mpnet-base-v2" + ) + + assert sas_result["sas"] == pytest.approx(1.0) + assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0]) + + @pytest.mark.integration + def test_sas_cross_encoder(self): + """ + Test calculation of Semantic Answer Similarity (SAS) Score using a Cross Encoder model. + """ + predictions = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + labels = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + evaluation_result = self.create_evaluation_result(predictions, labels) + sas_result = evaluation_result._calculate_sas( + output_key="answers", model="cross-encoder/ms-marco-MiniLM-L-6-v2" + ) + + assert sas_result["sas"] == pytest.approx(0.999967, abs=1e-5) + assert sas_result["scores"] == pytest.approx( + [0.9999765157699585, 0.999968409538269, 0.9999572038650513], abs=1e-5 + )
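Reviewer note on the cross-encoder expectations above: models such as `cross-encoder/ms-marco-MiniLM-L-6-v2` can emit raw logits rather than scores in [0, 1], which is why `_calculate_sas` applies a sigmoid when any score exceeds 1. A minimal sketch of that normalization, assuming `haystack.utils.expit` behaves like the standard logistic sigmoid (the input values below are made up for illustration):

```python
import numpy as np


def expit(x: np.ndarray) -> np.ndarray:
    # Logistic sigmoid; haystack.utils.expit is assumed to behave the same way.
    return 1.0 / (1.0 + np.exp(-x))


# Made-up raw cross-encoder logits; real values come from CrossEncoder.predict(...).
similarity_scores = np.array([8.2, 5.7, 9.1])

# Normalize only when the scores fall outside the expected (0, 1) range.
if (similarity_scores > 1).any():
    similarity_scores = expit(similarity_scores)

print(similarity_scores.tolist())  # values close to 1.0, in line with the test expectations
```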