feat: Add Semantic Answer Similarity metric (#6877)

* Add SAS metric

* Add release notes

* Round similarity scores for precision consistency

* Add tolerance to tests

* Update haystack/evaluation/eval.py

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* Add types for preprocess_text; Add additional types for f1 and em methods

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Ashwin Mathur 2024-02-02 21:37:52 +05:30 committed by GitHub
parent 461556cca2
commit 393a7993c3
6 changed files with 555 additions and 8 deletions
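
For orientation before the diffs, here is a minimal sketch of how the new metric is invoked; it assumes an `eval_result` (an `EvaluationResult`) already produced by evaluating a pipeline, as in the tests below.

```python
from haystack.evaluation.metrics import Metric

# Default SAS computation with the bi-encoder model used throughout the tests.
sas_default = eval_result.calculate_metrics(
    Metric.SAS, output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
print(sas_default["sas"])     # aggregate score (mean over all prediction-label pairs)
print(sas_default["scores"])  # one similarity score per prediction-label pair

# MetricsResult objects can be persisted to JSON.
sas_default.save("sas_score.json")
```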

View File

@@ -1,4 +1,5 @@
import json
import pytest
from haystack import Pipeline
from haystack.components.readers import ExtractiveReader
@@ -140,3 +141,26 @@ def test_extractive_qa_pipeline(tmp_path):
assert f1_custom_parameters["f1"] == 1.0
with open(tmp_path / "f1_score.json", "r") as f:
assert f1_default == json.load(f)
# Test SAS
sas_default = eval_result.calculate_metrics(
Metric.SAS, output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
sas_custom_parameters = eval_result.calculate_metrics(
Metric.SAS,
output_key="answers",
ignore_case=True,
ignore_punctuation=True,
ignore_numbers=True,
model="cross-encoder/ms-marco-MiniLM-L-6-v2",
)
# Save SAS metric results to json
sas_default.save(tmp_path / "sas_score.json")
assert sas_default["sas"] == pytest.approx(1.0)
assert sas_default["scores"] == pytest.approx([1.0, 1.0, 1.0])
assert sas_custom_parameters["sas"] == pytest.approx(0.9996823, abs=1e-5)
assert sas_custom_parameters["scores"] == pytest.approx([0.999672, 0.999608, 0.999767])
with open(tmp_path / "sas_score.json", "r") as f:
assert sas_default == json.load(f)
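
A side note on the assertions above: the similarity scores are floating-point, so the tests compare them with `pytest.approx` and an absolute tolerance (presumably what the "Add tolerance to tests" bullet in the commit message refers to). A minimal illustration:

```python
import pytest

# Scores that agree to within 1e-5 pass; larger drift fails the assertion.
assert 0.9996823 == pytest.approx(0.999680, abs=1e-5)
```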

View File

@@ -1,4 +1,5 @@
import json
import pytest
from haystack import Pipeline
from haystack.components.builders.answer_builder import AnswerBuilder
@@ -142,6 +143,29 @@ def test_bm25_rag_pipeline(tmp_path):
with open(tmp_path / "f1_score.json", "r") as f:
assert f1_default == json.load(f)
# Test SAS
sas_default = eval_result.calculate_metrics(
Metric.SAS, output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
sas_custom_parameters = eval_result.calculate_metrics(
Metric.SAS,
output_key="answers",
ignore_case=True,
ignore_punctuation=True,
ignore_numbers=True,
model="cross-encoder/ms-marco-MiniLM-L-6-v2",
)
# Save SAS metric results to json
sas_default.save(tmp_path / "sas_score.json")
assert sas_default["sas"] == pytest.approx(1.0)
assert sas_default["scores"] == pytest.approx([1.0, 1.0, 1.0])
assert sas_custom_parameters["sas"] == pytest.approx(0.9769593, abs=1e-5)
assert sas_custom_parameters["scores"] == pytest.approx([0.975823, 0.957218, 0.997837], abs=1e-5)
with open(tmp_path / "sas_score.json", "r") as f:
assert sas_default == json.load(f)
def test_embedding_retrieval_rag_pipeline(tmp_path):
# Create the RAG pipeline
@@ -287,3 +311,26 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
assert f1_custom_parameters["f1"] == 1.0
with open(tmp_path / "f1_score.json", "r") as f:
assert f1_default == json.load(f)
# Test SAS
sas_default = eval_result.calculate_metrics(
Metric.SAS, output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
sas_custom_parameters = eval_result.calculate_metrics(
Metric.SAS,
output_key="answers",
ignore_case=True,
ignore_punctuation=True,
ignore_numbers=True,
model="cross-encoder/ms-marco-MiniLM-L-6-v2",
)
# Save SAS metric results to json
sas_default.save(tmp_path / "sas_score.json")
assert sas_default["sas"] == pytest.approx(1.0)
assert sas_default["scores"] == pytest.approx([1.0, 1.0, 1.0])
assert sas_custom_parameters["sas"] == pytest.approx(0.9769593, abs=1e-5)
assert sas_custom_parameters["scores"] == pytest.approx([0.975823, 0.957218, 0.997837], abs=1e-5)
with open(tmp_path / "sas_score.json", "r") as f:
assert sas_default == json.load(f)

View File

@@ -1,5 +1,5 @@
import collections
from typing import Any, Callable, Dict, List, Union
from typing import Any, Callable, Dict, List, Union, Optional
import numpy as np
@@ -8,6 +8,13 @@ from haystack.core.component import Component
from haystack.evaluation.eval_utils import get_answers_from_output, preprocess_text
from haystack.evaluation.metrics import Metric, MetricsResult
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, expit
with LazyImport(message="Run 'pip install scikit-learn \"sentence-transformers>=2.2.0\"'") as metrics_import:
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from transformers import AutoConfig
class EvaluationResult:
"""
@@ -89,7 +96,12 @@ class EvaluationResult:
return f1
def _calculate_f1(
self, output_key: str, regexes_to_ignore=None, ignore_case=False, ignore_punctuation=False, ignore_numbers=False
self,
output_key: str,
regexes_to_ignore: Optional[List[str]] = None,
ignore_case: bool = False,
ignore_punctuation: bool = False,
ignore_numbers: bool = False,
) -> MetricsResult:
"""
Calculates the F1 score between two lists of predictions and labels.
@@ -103,7 +115,7 @@
comparison. Defaults to False.
:param ignore_numbers (bool, optional): If True, removes numerical digits from both predictions and labels
before comparison. Defaults to False.
:return: A MetricsResult object containing the calculated Exact Match (EM) score.
:return: A MetricsResult object containing the calculated F1 score.
"""
predictions = get_answers_from_output(
@@ -136,7 +148,12 @@
return MetricsResult({"f1": f1})
def _calculate_em(
self, output_key: str, regexes_to_ignore=None, ignore_case=False, ignore_punctuation=False, ignore_numbers=False
self,
output_key: str,
regexes_to_ignore: Optional[List[str]] = None,
ignore_case: bool = False,
ignore_punctuation: bool = False,
ignore_numbers: bool = False,
) -> MetricsResult:
"""
Calculates the Exact Match (EM) score between two lists of predictions and labels.
@@ -175,8 +192,106 @@
return MetricsResult({"exact_match": exact_match_score})
def _calculate_sas(self):
return MetricsResult({"exact_match": None})
def _calculate_sas(
self,
output_key: str,
regexes_to_ignore: Optional[List[str]] = None,
ignore_case: bool = False,
ignore_punctuation: bool = False,
ignore_numbers: bool = False,
model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
batch_size: int = 32,
device: Optional[ComponentDevice] = None,
token: Optional[Union[str, bool]] = None,
) -> MetricsResult:
"""
Calculates the Semantic Answer Similarity (SAS) score between two lists of predictions and labels.
Semantic Answer Similarity (SAS) score measures the Transformer-based similarity between the predicted text and
the corresponding ground truth label.
:param output_key: The key of the output to use for comparison.
:param regexes_to_ignore (list, optional): A list of regular expressions. If provided, it removes substrings
matching these regular expressions from both predictions and labels before comparison. Defaults to None.
:param ignore_case (bool, optional): If True, performs case-insensitive comparison. Defaults to False.
:param ignore_punctuation (bool, optional): If True, removes punctuation from both predictions and labels before
comparison. Defaults to False.
:param ignore_numbers (bool, optional): If True, removes numerical digits from both predictions and labels
before comparison. Defaults to False.
:param model: SentenceTransformers semantic textual similarity model; should be a path or a string pointing to
a downloadable model.
:param batch_size: Number of prediction-label pairs to encode at once.
:param device: The device on which the model is loaded. If `None`, the default device is automatically
selected.
:param token: The token to use as HTTP bearer authorization for private models on Hugging Face.
If True, the token generated when running huggingface-cli login (stored in ~/.huggingface) is used.
Additional information can be found here:
https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
:return: A MetricsResult object containing the calculated Semantic Answer Similarity (SAS) score and the
list of similarity scores obtained for each prediction-label pair.
"""
metrics_import.check()
predictions = get_answers_from_output(
outputs=self.outputs, output_key=output_key, runnable_type=self.runnable_type
)
labels = get_answers_from_output(
outputs=self.expected_outputs, output_key=output_key, runnable_type=self.runnable_type
)
if len(predictions) != len(labels):
raise ValueError("The number of predictions and labels must be the same.")
if len(predictions) == len(labels) == 0:
# Return SAS as 0 for no inputs
return MetricsResult({"sas": 0.0, "scores": [0.0]})
predictions = preprocess_text(predictions, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers)
labels = preprocess_text(labels, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers)
config = AutoConfig.from_pretrained(model, use_auth_token=token)
cross_encoder_used = False
if config.architectures:
cross_encoder_used = any(arch.endswith("ForSequenceClassification") for arch in config.architectures)
device = ComponentDevice.resolve_device(device)
# Based on the model string we load either a Bi-Encoder or a Cross Encoder;
# the similarity computation differs between the two approaches.
if cross_encoder_used:
# For Cross Encoders we create a list of pairs of predictions and labels
similarity_model = CrossEncoder(
model,
device=device.to_torch_str(),
tokenizer_args={"use_auth_token": token},
automodel_args={"use_auth_token": token},
)
sentence_pairs = [[pred, label] for pred, label in zip(predictions, labels)]
similarity_scores = similarity_model.predict(sentence_pairs, batch_size=batch_size, convert_to_numpy=True)
# Not all Cross Encoders return normalized scores;
# we apply the logistic sigmoid to normalize them if any score is larger than 1
if (similarity_scores > 1).any():
similarity_scores = expit(similarity_scores)
# Convert scores to list of floats from numpy array
similarity_scores = similarity_scores.tolist()
else:
# For Bi-encoders we create embeddings separately for predictions and labels
similarity_model = SentenceTransformer(model, device=device.to_torch_str(), use_auth_token=token)
pred_embeddings = similarity_model.encode(predictions, batch_size=batch_size, convert_to_tensor=True)
label_embeddings = similarity_model.encode(labels, batch_size=batch_size, convert_to_tensor=True)
# Compute cosine-similarities
scores = util.cos_sim(pred_embeddings, label_embeddings)
# cos_sim computes cosine similarity between all pairs of vectors in pred_embeddings and label_embeddings
# It returns a matrix with shape (len(predictions), len(labels))
similarity_scores = [scores[i][i].item() for i in range(len(predictions))]
sas_score = np.mean(similarity_scores)
return MetricsResult({"sas": sas_score, "scores": similarity_scores})
def eval(

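To make the two similarity paths in `_calculate_sas` above concrete, here is a self-contained NumPy sketch of both computations. It assumes `haystack.utils.expit` behaves like the standard logistic sigmoid and that `util.cos_sim` returns the full prediction-by-label cosine matrix, as the inline comments state.

```python
import numpy as np

def logistic_sigmoid(x: np.ndarray) -> np.ndarray:
    # Assumed behaviour of haystack.utils.expit: 1 / (1 + exp(-x))
    return 1.0 / (1.0 + np.exp(-x))

# Cross-encoder case: raw logits are squashed into (0, 1) only if any score exceeds 1.
raw_logits = np.array([8.2, 7.9, 9.1])
cross_scores = logistic_sigmoid(raw_logits) if (raw_logits > 1).any() else raw_logits

# Bi-encoder case: cosine similarity between each prediction/label embedding pair.
# The full (n_pred, n_label) matrix is computed, but only the diagonal
# (prediction i vs. label i) is relevant for SAS.
pred_emb = np.random.rand(3, 8)
label_emb = np.random.rand(3, 8)
pred_emb /= np.linalg.norm(pred_emb, axis=1, keepdims=True)
label_emb /= np.linalg.norm(label_emb, axis=1, keepdims=True)
cos_matrix = pred_emb @ label_emb.T
similarity_scores = np.diag(cos_matrix).tolist()

# The aggregate SAS score is the mean over the per-pair scores.
sas_score = float(np.mean(similarity_scores))
```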
View File

@@ -1,10 +1,14 @@
import re
import string
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
def preprocess_text(
texts: List[str], regexes_to_ignore=None, ignore_case=False, ignore_punctuation=False, ignore_numbers=False
texts: List[str],
regexes_to_ignore: Optional[List[str]] = None,
ignore_case: bool = False,
ignore_punctuation: bool = False,
ignore_numbers: bool = False,
) -> List[str]:
"""
Preprocess the outputs of the runnable to remove unwanted characters.
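Only the signature of `preprocess_text` changes in this commit. For readers unfamiliar with the helper, below is a hypothetical re-implementation of what these flags typically do (illustrative only, not the actual Haystack code):

```python
import re
import string
from typing import List, Optional

def preprocess_text_sketch(
    texts: List[str],
    regexes_to_ignore: Optional[List[str]] = None,
    ignore_case: bool = False,
    ignore_punctuation: bool = False,
    ignore_numbers: bool = False,
) -> List[str]:
    # Illustrative only: strips configured patterns/characters before comparison.
    processed = list(texts)
    if regexes_to_ignore:
        for pattern in regexes_to_ignore:
            processed = [re.sub(pattern, "", t) for t in processed]
    if ignore_case:
        processed = [t.lower() for t in processed]
    if ignore_punctuation:
        table = str.maketrans("", "", string.punctuation)
        processed = [t.translate(table) for t in processed]
    if ignore_numbers:
        digits = str.maketrans("", "", string.digits)
        processed = [t.translate(digits) for t in processed]
    return processed
```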

View File

@@ -0,0 +1,10 @@
---
features:
- |
Adds support for the Semantic Answer Similarity (SAS) metric to `EvaluationResult.calculate_metrics(...)`:
```python
from haystack.evaluation.metrics import Metric
sas_metric = eval_result.calculate_metrics(
Metric.SAS, output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
```
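
As a follow-up usage note, the preprocessing flags and a cross-encoder model can be combined in the same call, mirroring the integration tests in this commit (again assuming an existing `eval_result`):

```python
sas_custom = eval_result.calculate_metrics(
    Metric.SAS,
    output_key="answers",
    ignore_case=True,
    ignore_punctuation=True,
    ignore_numbers=True,
    model="cross-encoder/ms-marco-MiniLM-L-6-v2",
)
```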

View File

@@ -0,0 +1,347 @@
import pytest
from haystack import Pipeline
from haystack.dataclasses import GeneratedAnswer
from haystack.evaluation.eval import EvaluationResult
class TestSAS:
def create_evaluation_result(self, predictions, labels):
"""
Creates an evaluation result for a RAG pipeline from the given lists of predictions and labels, used to test the
Semantic Answer Similarity (SAS) metric.
"""
runnable = Pipeline()
inputs = []
outputs = [
{"answer_builder": {"answers": [GeneratedAnswer(data=pred, query="", documents=[], meta={})]}}
for pred in predictions
]
expected_outputs = [
{"answer_builder": {"answers": [GeneratedAnswer(data=label, query="", documents=[], meta={})]}}
for label in labels
]
evaluation_result = EvaluationResult(runnable, inputs, outputs, expected_outputs)
return evaluation_result
def test_sas_empty_inputs(self):
"""
Test calculation of Semantic Answer Similarity (SAS) Score with empty inputs.
"""
runnable = Pipeline()
inputs = []
outputs = [
{"answer_builder": {"answers": []}},
{"answer_builder": {"answers": []}},
{"answer_builder": {"answers": []}},
]
expected_outputs = [
{"answer_builder": {"answers": []}},
{"answer_builder": {"answers": []}},
{"answer_builder": {"answers": []}},
]
evaluation_result = EvaluationResult(runnable, inputs, outputs, expected_outputs)
# Expecting a SAS score of 0.0 for empty inputs
sas_result = evaluation_result._calculate_sas(
output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
assert sas_result["sas"] == 0.0
assert sas_result["scores"] == [0.0]
def test_calculate_sas_with_different_lengths(self):
"""
Test that calculating the Semantic Answer Similarity (SAS) Score raises an error when the numbers of predictions and labels differ.
"""
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
labels = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
]
evaluation_result = self.create_evaluation_result(predictions, labels)
with pytest.raises(ValueError, match="The number of predictions and labels must be the same."):
evaluation_result._calculate_sas(
output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
@pytest.mark.integration
def test_sas_same_inputs(self):
"""
Test calculation of Semantic Answer Similarity (SAS) Score with default parameters.
"""
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
labels = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
evaluation_result = self.create_evaluation_result(predictions, labels)
sas_result = evaluation_result._calculate_sas(
output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
assert sas_result["sas"] == pytest.approx(1.0)
assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])
@pytest.mark.integration
def test_sas_single_word(self):
"""
Test calculation of Semantic Answer Similarity (SAS) Score with a single prediction-label pair.
"""
predictions = ["A construction budget of US $2.3 billion"]
labels = ["US $2.3 billion"]
evaluation_result = self.create_evaluation_result(predictions, labels)
sas_result = evaluation_result._calculate_sas(
output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
assert sas_result["sas"] == pytest.approx(0.689089, abs=1e-5)
assert sas_result["scores"] == pytest.approx([0.689089], abs=1e-5)
@pytest.mark.integration
def test_sas_negative_case(self):
"""
Test calculation of Semantic Answer Similarity (SAS) Score with deliberately mismatched predictions and labels.
"""
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
labels = [
"US $2.3 billion",
"Paris's cultural magnificence is symbolized by the Eiffel Tower",
"Japan was transformed into a modernized world power after the Meiji Restoration.",
]
evaluation_result = self.create_evaluation_result(predictions, labels)
sas_result = evaluation_result._calculate_sas(
output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
)
assert sas_result["sas"] == pytest.approx(0.8227189)
assert sas_result["scores"] == pytest.approx([0.689089, 0.870389, 0.908679], abs=1e-5)
@pytest.mark.integration
def test_sas_ignore_case(self):
"""
Test calculation of Semantic Answer Similarity (SAS) Score with ignoring case sensitivity.
"""
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
labels = [
"A construction budget of US $2.3 BILLION",
"The EIFFEL TOWER, completed in 1889, symbolizes Paris's cultural magnificence.",
"The MEIJI RESTORATION in 1868 transformed Japan into a modernized world power.",
]
evaluation_result = self.create_evaluation_result(predictions, labels)
# SAS after case ignoring
sas_result = evaluation_result._calculate_sas(
output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2", ignore_case=True
)
assert sas_result["sas"] == pytest.approx(1.0)
assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])
@pytest.mark.integration
def test_sas_ignore_punctuation(self):
"""
Test calculation of Semantic Answer Similarity (SAS) Score with ignoring punctuation.
"""
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
]
labels = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower completed in 1889 symbolizes Paris's cultural magnificence",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power",
]
evaluation_result = self.create_evaluation_result(predictions, labels)
# SAS after ignoring punctuation
sas_result = evaluation_result._calculate_sas(
output_key="answers",
model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
ignore_punctuation=True,
)
assert sas_result["sas"] == pytest.approx(1.0)
assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])
@pytest.mark.integration
def test_sas_ignore_numbers(self):
"""
Test calculation of Semantic Answer Similarity (SAS) Score with ignoring numbers.
"""
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
]
labels = [
"A construction budget of US $10.3 billion",
"The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
]
evaluation_result = self.create_evaluation_result(predictions, labels)
# SAS after ignoring numbers
sas_result = evaluation_result._calculate_sas(
output_key="answers",
model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
ignore_numbers=True,
)
assert sas_result["sas"] == pytest.approx(1.0)
assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])
@pytest.mark.integration
def test_sas_regex_ignore(self):
"""
Test calculation of Semantic Answer Similarity (SAS) Score with ignoring specific regex patterns.
"""
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
]
labels = [
"A construction budget of US $10.3 billion",
"The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
]
evaluation_result = self.create_evaluation_result(predictions, labels)
# Ignore numeric patterns
regex_to_ignore = [r"\d+"]
sas_result = evaluation_result._calculate_sas(
output_key="answers",
model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
regexes_to_ignore=regex_to_ignore,
)
assert sas_result["sas"] == pytest.approx(1.0)
assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])
@pytest.mark.integration
def test_sas_multiple_ignore_regex(self):
"""
Test calculation of Semantic Answer Similarity (SAS) Score with multiple regex patterns to ignore.
"""
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
]
labels = [
"A construction budget of US #10.3 billion",
"The Eiffel Tower!!, completed in 2005, symbolizes Paris's cultural magnificence.",
"The **Meiji Restoration**, in 1989, transformed Japan into a modernized world power.",
]
evaluation_result = self.create_evaluation_result(predictions, labels)
# Ignore numeric patterns and punctuation excluding whitespaces
regex_to_ignore = [r"\d+", r"[^\w\s]"]
sas_result = evaluation_result._calculate_sas(
output_key="answers",
model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
regexes_to_ignore=regex_to_ignore,
)
assert sas_result["sas"] == pytest.approx(1.0)
assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])
@pytest.mark.integration
def test_sas_multiple_ignore_combination(self):
"""
Test calculation of Semantic Answer Similarity (SAS) Score with multiple ignoring parameters combined.
"""
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
]
labels = [
"A construction budget of US #10.3 BILLION",
"The EIFFEL TOWER!!, completed in 2005, symbolizes Paris's cultural magnificence.",
"The **MEIJI RESTORATION**, in 1989, transformed Japan into a modernized world power.",
]
evaluation_result = self.create_evaluation_result(predictions, labels)
# Ignore only special characters using regex
regex_to_ignore = [r"[^\w\s\d]+"]
sas_result = evaluation_result._calculate_sas(
output_key="answers",
model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
ignore_numbers=True,
ignore_punctuation=True,
ignore_case=True,
regexes_to_ignore=regex_to_ignore,
)
assert sas_result["sas"] == pytest.approx(1.0)
assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])
@pytest.mark.integration
def test_sas_bi_encoder(self):
"""
Test calculation of Semantic Answer Similarity (SAS) Score using a Bi-Encoder model.
"""
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
labels = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
evaluation_result = self.create_evaluation_result(predictions, labels)
sas_result = evaluation_result._calculate_sas(
output_key="answers", model="sentence-transformers/all-mpnet-base-v2"
)
assert sas_result["sas"] == pytest.approx(1.0)
assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])
@pytest.mark.integration
def test_sas_cross_encoder(self):
"""
Test calculation of Semantic Answer Similarity (SAS) Score using a Cross Encoder model.
"""
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
labels = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
evaluation_result = self.create_evaluation_result(predictions, labels)
sas_result = evaluation_result._calculate_sas(
output_key="answers", model="cross-encoder/ms-marco-MiniLM-L-6-v2"
)
assert sas_result["sas"] == pytest.approx(0.999967, abs=1e-5)
assert sas_result["scores"] == pytest.approx(
[0.9999765157699585, 0.999968409538269, 0.9999572038650513], abs=1e-5
)
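
The two tests above cover both model families; `_calculate_sas` decides between them by inspecting the model configuration, roughly as in this sketch (which downloads the config from the Hugging Face Hub):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")
# Cross encoders expose a *ForSequenceClassification architecture in their config.
is_cross_encoder = any(
    arch.endswith("ForSequenceClassification") for arch in (config.architectures or [])
)
print(is_cross_encoder)  # True for this model
```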