Mirror of https://github.com/deepset-ai/haystack.git
feat: Add Semantic Answer Similarity metric (#6877)
* Add SAS metric
* Add release notes
* Round similarity scores for precision consistency
* Add tolerance to tests
* Update haystack/evaluation/eval.py
* Add types for preprocess_text; add additional types for f1 and em methods

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
parent 461556cca2
commit 393a7993c3
@@ -1,4 +1,5 @@
 import json
+import pytest
 
 from haystack import Pipeline
 from haystack.components.readers import ExtractiveReader
@@ -140,3 +141,26 @@ def test_extractive_qa_pipeline(tmp_path):
     assert f1_custom_parameters["f1"] == 1.0
     with open(tmp_path / "f1_score.json", "r") as f:
         assert f1_default == json.load(f)
+
+    # Test SAS
+    sas_default = eval_result.calculate_metrics(
+        Metric.SAS, output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+    )
+    sas_custom_parameters = eval_result.calculate_metrics(
+        Metric.SAS,
+        output_key="answers",
+        ignore_case=True,
+        ignore_punctuation=True,
+        ignore_numbers=True,
+        model="cross-encoder/ms-marco-MiniLM-L-6-v2",
+    )
+    # Save SAS metric results to json
+    sas_default.save(tmp_path / "sas_score.json")
+
+    assert sas_default["sas"] == pytest.approx(1.0)
+    assert sas_default["scores"] == pytest.approx([1.0, 1.0, 1.0])
+    assert sas_custom_parameters["sas"] == pytest.approx(0.9996823, abs=1e-5)
+    assert sas_custom_parameters["scores"] == pytest.approx([0.999672, 0.999608, 0.999767])
+
+    with open(tmp_path / "sas_score.json", "r") as f:
+        assert sas_default == json.load(f)
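The assertions above compare floating-point similarity scores, which is why the commit adds tolerances ("Add tolerance to tests", "Round similarity scores for precision consistency"). A minimal sketch of the `pytest.approx` semantics these tests rely on:

```python
import pytest

# approx with an absolute tolerance accepts small rounding drift in the scores:
assert 0.99967201 == pytest.approx(0.999672, abs=1e-5)
assert [0.99967201, 0.99960799] == pytest.approx([0.999672, 0.999608], abs=1e-5)

# Without arguments, approx defaults to a relative tolerance of 1e-6,
# which is what the comparisons against 1.0 use:
assert 1.0000000001 == pytest.approx(1.0)
```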
@@ -1,4 +1,5 @@
 import json
+import pytest
 
 from haystack import Pipeline
 from haystack.components.builders.answer_builder import AnswerBuilder
@@ -142,6 +143,29 @@ def test_bm25_rag_pipeline(tmp_path):
     with open(tmp_path / "f1_score.json", "r") as f:
         assert f1_default == json.load(f)
 
+    # Test SAS
+    sas_default = eval_result.calculate_metrics(
+        Metric.SAS, output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+    )
+    sas_custom_parameters = eval_result.calculate_metrics(
+        Metric.SAS,
+        output_key="answers",
+        ignore_case=True,
+        ignore_punctuation=True,
+        ignore_numbers=True,
+        model="cross-encoder/ms-marco-MiniLM-L-6-v2",
+    )
+    # Save SAS metric results to json
+    sas_default.save(tmp_path / "sas_score.json")
+
+    assert sas_default["sas"] == pytest.approx(1.0)
+    assert sas_default["scores"] == pytest.approx([1.0, 1.0, 1.0])
+    assert sas_custom_parameters["sas"] == pytest.approx(0.9769593, abs=1e-5)
+    assert sas_custom_parameters["scores"] == pytest.approx([0.975823, 0.957218, 0.997837], abs=1e-5)
+
+    with open(tmp_path / "sas_score.json", "r") as f:
+        assert sas_default == json.load(f)
+
 
 def test_embedding_retrieval_rag_pipeline(tmp_path):
     # Create the RAG pipeline
@@ -287,3 +311,26 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
     assert f1_custom_parameters["f1"] == 1.0
     with open(tmp_path / "f1_score.json", "r") as f:
         assert f1_default == json.load(f)
+
+    # Test SAS
+    sas_default = eval_result.calculate_metrics(
+        Metric.SAS, output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
+    )
+    sas_custom_parameters = eval_result.calculate_metrics(
+        Metric.SAS,
+        output_key="answers",
+        ignore_case=True,
+        ignore_punctuation=True,
+        ignore_numbers=True,
+        model="cross-encoder/ms-marco-MiniLM-L-6-v2",
+    )
+    # Save SAS metric results to json
+    sas_default.save(tmp_path / "sas_score.json")
+
+    assert sas_default["sas"] == pytest.approx(1.0)
+    assert sas_default["scores"] == pytest.approx([1.0, 1.0, 1.0])
+    assert sas_custom_parameters["sas"] == pytest.approx(0.9769593, abs=1e-5)
+    assert sas_custom_parameters["scores"] == pytest.approx([0.975823, 0.957218, 0.997837], abs=1e-5)
+
+    with open(tmp_path / "sas_score.json", "r") as f:
+        assert sas_default == json.load(f)
@@ -1,5 +1,5 @@
 import collections
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Callable, Dict, List, Union, Optional
 
 import numpy as np
 
@@ -8,6 +8,13 @@ from haystack.core.component import Component
 from haystack.evaluation.eval_utils import get_answers_from_output, preprocess_text
 from haystack.evaluation.metrics import Metric, MetricsResult
+from haystack.lazy_imports import LazyImport
+from haystack.utils import ComponentDevice, expit
+
+with LazyImport(message="Run 'pip install scikit-learn \"sentence-transformers>=2.2.0\"'") as metrics_import:
+    from sentence_transformers import SentenceTransformer, CrossEncoder, util
+    from transformers import AutoConfig
 
 
 class EvaluationResult:
     """
@@ -89,7 +96,12 @@ class EvaluationResult:
         return f1
 
     def _calculate_f1(
-        self, output_key: str, regexes_to_ignore=None, ignore_case=False, ignore_punctuation=False, ignore_numbers=False
+        self,
+        output_key: str,
+        regexes_to_ignore: Optional[List[str]] = None,
+        ignore_case: bool = False,
+        ignore_punctuation: bool = False,
+        ignore_numbers: bool = False,
     ) -> MetricsResult:
         """
         Calculates the F1 score between two lists of predictions and labels.
@@ -103,7 +115,7 @@ class EvaluationResult:
             comparison. Defaults to False.
         :param ignore_numbers (bool, optional): If True, removes numerical digits from both predictions and labels
             before comparison. Defaults to False.
-        :return: A MetricsResult object containing the calculated Exact Match (EM) score.
+        :return: A MetricsResult object containing the calculated F1 score.
         """
 
         predictions = get_answers_from_output(
@@ -136,7 +148,12 @@ class EvaluationResult:
         return MetricsResult({"f1": f1})
 
     def _calculate_em(
-        self, output_key: str, regexes_to_ignore=None, ignore_case=False, ignore_punctuation=False, ignore_numbers=False
+        self,
+        output_key: str,
+        regexes_to_ignore: Optional[List[str]] = None,
+        ignore_case: bool = False,
+        ignore_punctuation: bool = False,
+        ignore_numbers: bool = False,
    ) -> MetricsResult:
         """
         Calculates the Exact Match (EM) score between two lists of predictions and labels.
@@ -175,8 +192,106 @@ class EvaluationResult:
 
         return MetricsResult({"exact_match": exact_match_score})
 
-    def _calculate_sas(self):
-        return MetricsResult({"exact_match": None})
+    def _calculate_sas(
+        self,
+        output_key: str,
+        regexes_to_ignore: Optional[List[str]] = None,
+        ignore_case: bool = False,
+        ignore_punctuation: bool = False,
+        ignore_numbers: bool = False,
+        model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
+        batch_size: int = 32,
+        device: Optional[ComponentDevice] = None,
+        token: Optional[Union[str, bool]] = None,
+    ) -> MetricsResult:
+        """
+        Calculates the Semantic Answer Similarity (SAS) score between two lists of predictions and labels.
+        The Semantic Answer Similarity (SAS) score measures the Transformer-based similarity between the predicted
+        text and the corresponding ground truth label.
+
+        :param output_key: The key of the output to use for comparison.
+        :param regexes_to_ignore (list, optional): A list of regular expressions. If provided, it removes substrings
+            matching these regular expressions from both predictions and labels before comparison. Defaults to None.
+        :param ignore_case (bool, optional): If True, performs case-insensitive comparison. Defaults to False.
+        :param ignore_punctuation (bool, optional): If True, removes punctuation from both predictions and labels
+            before comparison. Defaults to False.
+        :param ignore_numbers (bool, optional): If True, removes numerical digits from both predictions and labels
+            before comparison. Defaults to False.
+        :param model: SentenceTransformers semantic textual similarity model; should be a path or string pointing to
+            a downloadable model.
+        :param batch_size: Number of prediction-label pairs to encode at once.
+        :param device: The device on which the model is loaded. If `None`, the default device is automatically
+            selected.
+        :param token: The token to use as HTTP bearer authorization for private models on Hugging Face.
+            If True, will use the token generated when running huggingface-cli login (stored in ~/.huggingface).
+            Additional information can be found here:
+            https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
+        :return: A MetricsResult object containing the calculated Semantic Answer Similarity (SAS) score and the
+            list of similarity scores obtained for each prediction-label pair.
+        """
+        metrics_import.check()
+
+        predictions = get_answers_from_output(
+            outputs=self.outputs, output_key=output_key, runnable_type=self.runnable_type
+        )
+        labels = get_answers_from_output(
+            outputs=self.expected_outputs, output_key=output_key, runnable_type=self.runnable_type
+        )
+
+        if len(predictions) != len(labels):
+            raise ValueError("The number of predictions and labels must be the same.")
+        if len(predictions) == len(labels) == 0:
+            # Return SAS as 0 for no inputs
+            return MetricsResult({"sas": 0.0, "scores": [0.0]})
+
+        predictions = preprocess_text(predictions, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers)
+        labels = preprocess_text(labels, regexes_to_ignore, ignore_case, ignore_punctuation, ignore_numbers)
+
+        config = AutoConfig.from_pretrained(model, use_auth_token=token)
+        cross_encoder_used = False
+        if config.architectures:
+            cross_encoder_used = any(arch.endswith("ForSequenceClassification") for arch in config.architectures)
+
+        device = ComponentDevice.resolve_device(device)
+
+        # Based on the model string we can load either a Bi-Encoder or a Cross Encoder.
+        # The similarity computation differs between the two approaches.
+
+        if cross_encoder_used:
+            # For Cross Encoders we create a list of pairs of predictions and labels
+            similarity_model = CrossEncoder(
+                model,
+                device=device.to_torch_str(),
+                tokenizer_args={"use_auth_token": token},
+                automodel_args={"use_auth_token": token},
+            )
+            sentence_pairs = [[pred, label] for pred, label in zip(predictions, labels)]
+            similarity_scores = similarity_model.predict(sentence_pairs, batch_size=batch_size, convert_to_numpy=True)
+
+            # Not all Cross Encoders return normalized scores,
+            # so we normalize the scores if any of them is larger than 1
+            if (similarity_scores > 1).any():
+                similarity_scores = expit(similarity_scores)
+
+            # Convert scores from a numpy array to a list of floats
+            similarity_scores = similarity_scores.tolist()
+
+        else:
+            # For Bi-encoders we create embeddings separately for predictions and labels
+            similarity_model = SentenceTransformer(model, device=device.to_torch_str(), use_auth_token=token)
+            pred_embeddings = similarity_model.encode(predictions, batch_size=batch_size, convert_to_tensor=True)
+            label_embeddings = similarity_model.encode(labels, batch_size=batch_size, convert_to_tensor=True)
+
+            # Compute cosine-similarities
+            scores = util.cos_sim(pred_embeddings, label_embeddings)
+
+            # cos_sim computes cosine similarity between all pairs of vectors in pred_embeddings and label_embeddings
+            # It returns a matrix with shape (len(predictions), len(labels))
+            similarity_scores = [scores[i][i].item() for i in range(len(predictions))]
+
+        sas_score = np.mean(similarity_scores)
+
+        return MetricsResult({"sas": sas_score, "scores": similarity_scores})
+
 
 def eval(
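To see the two branches above in isolation, here is a minimal standalone sketch (assuming `sentence-transformers` is installed; the model names are the ones used in this patch) of how a bi-encoder and a cross-encoder score the same prediction-label pair:

```python
from sentence_transformers import CrossEncoder, SentenceTransformer, util

predictions = ["A construction budget of US $2.3 billion"]
labels = ["US $2.3 billion"]

# Bi-encoder branch: embed predictions and labels separately, then read the
# diagonal of the pairwise cosine-similarity matrix (pair i with pair i).
bi_encoder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
pred_emb = bi_encoder.encode(predictions, convert_to_tensor=True)
label_emb = bi_encoder.encode(labels, convert_to_tensor=True)
cos_matrix = util.cos_sim(pred_emb, label_emb)  # shape: (len(predictions), len(labels))
bi_scores = [cos_matrix[i][i].item() for i in range(len(predictions))]

# Cross-encoder branch: each (prediction, label) pair is scored jointly in one
# forward pass; the output can be an unnormalized logit, which is why the patch
# squashes scores through expit (the logistic sigmoid) when any exceed 1.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
cross_scores = cross_encoder.predict([[p, l] for p, l in zip(predictions, labels)])

print(bi_scores, cross_scores)
```

The SAS value the metric reports is simply the mean of these per-pair scores, which is what the `np.mean(similarity_scores)` line computes.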
@@ -1,10 +1,14 @@
 import re
 import string
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 
 def preprocess_text(
-    texts: List[str], regexes_to_ignore=None, ignore_case=False, ignore_punctuation=False, ignore_numbers=False
+    texts: List[str],
+    regexes_to_ignore: Optional[List[str]] = None,
+    ignore_case: bool = False,
+    ignore_punctuation: bool = False,
+    ignore_numbers: bool = False,
 ) -> List[str]:
     """
     Preprocess the outputs of the runnable to remove unwanted characters.
 
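This hunk only re-types the signature of `preprocess_text`; its body is unchanged and not shown in the diff. For orientation, a hypothetical re-implementation of what the four flags imply (an illustrative sketch only, not the actual Haystack body; `preprocess_text_sketch` is a made-up name):

```python
import re
import string
from typing import List, Optional


def preprocess_text_sketch(
    texts: List[str],
    regexes_to_ignore: Optional[List[str]] = None,
    ignore_case: bool = False,
    ignore_punctuation: bool = False,
    ignore_numbers: bool = False,
) -> List[str]:
    # Assumed behavior inferred from the parameter names; the real body may differ.
    if regexes_to_ignore:
        combined = "|".join(regexes_to_ignore)
        texts = [re.sub(combined, "", text) for text in texts]
    if ignore_case:
        texts = [text.lower() for text in texts]
    if ignore_punctuation:
        table = str.maketrans("", "", string.punctuation)
        texts = [text.translate(table) for text in texts]
    if ignore_numbers:
        digit_table = str.maketrans("", "", string.digits)
        texts = [text.translate(digit_table) for text in texts]
    return texts


# Example: stripping years with a regex, as in test_sas_regex_ignore below
print(preprocess_text_sketch(["Completed in 1889."], regexes_to_ignore=[r"\d+"]))
```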
releasenotes/notes/add-sas-b8dbf61c0d78ba19.yaml (new file)
@@ -0,0 +1,10 @@
---
features:
  - |
    Adds support for the Semantic Answer Similarity (SAS) metric to `EvaluationResult.calculate_metrics(...)`:
    ```python
    from haystack.evaluation.metrics import Metric
    sas_metric = eval_result.calculate_metrics(
        Metric.SAS, output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
    )
    ```
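The same call also accepts the text-normalization flags shared with the EM and F1 metrics, plus model selection. For instance, mirroring the updated end-to-end tests above (this assumes an existing `eval_result`, as in the release-note snippet):

```python
from haystack.evaluation.metrics import Metric

sas_custom = eval_result.calculate_metrics(
    Metric.SAS,
    output_key="answers",
    ignore_case=True,
    ignore_punctuation=True,
    ignore_numbers=True,
    model="cross-encoder/ms-marco-MiniLM-L-6-v2",
)
# MetricsResult objects round-trip through JSON:
sas_custom.save("sas_score.json")
```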
test/evaluation/test_eval_sas.py (new file)
@@ -0,0 +1,347 @@
import pytest

from haystack import Pipeline
from haystack.dataclasses import GeneratedAnswer
from haystack.evaluation.eval import EvaluationResult


class TestSAS:
    def create_evaluation_result(self, predictions, labels):
        """
        Creates an evaluation result of a RAG pipeline using the list of predictions and labels for testing the
        Semantic Answer Similarity (SAS) Metric.
        """
        runnable = Pipeline()
        inputs = []
        outputs = [
            {"answer_builder": {"answers": [GeneratedAnswer(data=pred, query="", documents=[], meta={})]}}
            for pred in predictions
        ]
        expected_outputs = [
            {"answer_builder": {"answers": [GeneratedAnswer(data=label, query="", documents=[], meta={})]}}
            for label in labels
        ]
        evaluation_result = EvaluationResult(runnable, inputs, outputs, expected_outputs)
        return evaluation_result

    def test_sas_empty_inputs(self):
        """
        Test calculation of Semantic Answer Similarity (SAS) Score with empty inputs.
        """
        runnable = Pipeline()
        inputs = []
        outputs = [
            {"answer_builder": {"answers": []}},
            {"answer_builder": {"answers": []}},
            {"answer_builder": {"answers": []}},
        ]
        expected_outputs = [
            {"answer_builder": {"answers": []}},
            {"answer_builder": {"answers": []}},
            {"answer_builder": {"answers": []}},
        ]
        evaluation_result = EvaluationResult(runnable, inputs, outputs, expected_outputs)
        # Expecting 0% SAS for empty inputs
        sas_result = evaluation_result._calculate_sas(
            output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
        )

        assert sas_result["sas"] == 0.0
        assert sas_result["scores"] == [0.0]
    def test_calculate_sas_with_different_lengths(self):
        """
        Test that calculating the Semantic Answer Similarity (SAS) Score raises an error when predictions and
        labels have different lengths.
        """
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        labels = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
        ]
        evaluation_result = self.create_evaluation_result(predictions, labels)

        with pytest.raises(ValueError, match="The number of predictions and labels must be the same."):
            evaluation_result._calculate_sas(
                output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
            )

    @pytest.mark.integration
    def test_sas_same_inputs(self):
        """
        Test calculation of Semantic Answer Similarity (SAS) Score with default parameters.
        """
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        labels = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        evaluation_result = self.create_evaluation_result(predictions, labels)
        sas_result = evaluation_result._calculate_sas(
            output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
        )

        assert sas_result["sas"] == pytest.approx(1.0)
        assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])

    @pytest.mark.integration
    def test_sas_single_word(self):
        """
        Test calculation of Semantic Answer Similarity (SAS) Score with short inputs.
        """
        predictions = ["A construction budget of US $2.3 billion"]
        labels = ["US $2.3 billion"]

        evaluation_result = self.create_evaluation_result(predictions, labels)
        sas_result = evaluation_result._calculate_sas(
            output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
        )

        assert sas_result["sas"] == pytest.approx(0.689089, abs=1e-5)
        assert sas_result["scores"] == pytest.approx([0.689089], abs=1e-5)

    @pytest.mark.integration
    def test_sas_negative_case(self):
        """
        Test calculation of Semantic Answer Similarity (SAS) Score with deliberately mismatched predictions and
        labels.
        """
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        labels = [
            "US $2.3 billion",
            "Paris's cultural magnificence is symbolized by the Eiffel Tower",
            "Japan was transformed into a modernized world power after the Meiji Restoration.",
        ]

        evaluation_result = self.create_evaluation_result(predictions, labels)
        sas_result = evaluation_result._calculate_sas(
            output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
        )

        assert sas_result["sas"] == pytest.approx(0.8227189)
        assert sas_result["scores"] == pytest.approx([0.689089, 0.870389, 0.908679], abs=1e-5)

    @pytest.mark.integration
    def test_sas_ignore_case(self):
        """
        Test calculation of Semantic Answer Similarity (SAS) Score while ignoring case.
        """
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        labels = [
            "A construction budget of US $2.3 BILLION",
            "The EIFFEL TOWER, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The MEIJI RESTORATION in 1868 transformed Japan into a modernized world power.",
        ]

        evaluation_result = self.create_evaluation_result(predictions, labels)
        # SAS with case ignored
        sas_result = evaluation_result._calculate_sas(
            output_key="answers", model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2", ignore_case=True
        )

        assert sas_result["sas"] == pytest.approx(1.0)
        assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])

    @pytest.mark.integration
    def test_sas_ignore_punctuation(self):
        """
        Test calculation of Semantic Answer Similarity (SAS) Score while ignoring punctuation.
        """
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
        ]
        labels = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower completed in 1889 symbolizes Paris's cultural magnificence",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power",
        ]

        evaluation_result = self.create_evaluation_result(predictions, labels)
        # SAS with punctuation ignored
        sas_result = evaluation_result._calculate_sas(
            output_key="answers",
            model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
            ignore_punctuation=True,
        )

        assert sas_result["sas"] == pytest.approx(1.0)
        assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])

    @pytest.mark.integration
    def test_sas_ignore_numbers(self):
        """
        Test calculation of Semantic Answer Similarity (SAS) Score while ignoring numbers.
        """
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
        ]
        labels = [
            "A construction budget of US $10.3 billion",
            "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
        ]

        evaluation_result = self.create_evaluation_result(predictions, labels)
        # SAS with numbers ignored
        sas_result = evaluation_result._calculate_sas(
            output_key="answers",
            model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
            ignore_numbers=True,
        )

        assert sas_result["sas"] == pytest.approx(1.0)
        assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])

    @pytest.mark.integration
    def test_sas_regex_ignore(self):
        """
        Test calculation of Semantic Answer Similarity (SAS) Score while ignoring specific regex patterns.
        """
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
        ]
        labels = [
            "A construction budget of US $10.3 billion",
            "The Eiffel Tower, completed in 2005, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration, in 1989, transformed Japan into a modernized world power.",
        ]

        evaluation_result = self.create_evaluation_result(predictions, labels)
        # Ignore numeric patterns
        regex_to_ignore = [r"\d+"]
        sas_result = evaluation_result._calculate_sas(
            output_key="answers",
            model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
            regexes_to_ignore=regex_to_ignore,
        )

        assert sas_result["sas"] == pytest.approx(1.0)
        assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])

    @pytest.mark.integration
    def test_sas_multiple_ignore_regex(self):
        """
        Test calculation of Semantic Answer Similarity (SAS) Score while ignoring multiple regex patterns.
        """
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
        ]
        labels = [
            "A construction budget of US #10.3 billion",
            "The Eiffel Tower!!, completed in 2005, symbolizes Paris's cultural magnificence.",
            "The **Meiji Restoration**, in 1989, transformed Japan into a modernized world power.",
        ]

        evaluation_result = self.create_evaluation_result(predictions, labels)
        # Ignore numeric patterns and punctuation, excluding whitespace
        regex_to_ignore = [r"\d+", r"[^\w\s]"]
        sas_result = evaluation_result._calculate_sas(
            output_key="answers",
            model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
            regexes_to_ignore=regex_to_ignore,
        )

        assert sas_result["sas"] == pytest.approx(1.0)
        assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])

    @pytest.mark.integration
    def test_sas_multiple_ignore_combination(self):
        """
        Test calculation of Semantic Answer Similarity (SAS) Score with multiple ignoring parameters combined.
        """
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration, in 1868, transformed Japan into a modernized world power.",
        ]
        labels = [
            "A construction budget of US #10.3 BILLION",
            "The EIFFEL TOWER!!, completed in 2005, symbolizes Paris's cultural magnificence.",
            "The **MEIJI RESTORATION**, in 1989, transformed Japan into a modernized world power.",
        ]

        evaluation_result = self.create_evaluation_result(predictions, labels)
        # Ignore only special characters using regex
        regex_to_ignore = [r"[^\w\s\d]+"]
        sas_result = evaluation_result._calculate_sas(
            output_key="answers",
            model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
            ignore_numbers=True,
            ignore_punctuation=True,
            ignore_case=True,
            regexes_to_ignore=regex_to_ignore,
        )

        assert sas_result["sas"] == pytest.approx(1.0)
        assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])

    @pytest.mark.integration
    def test_sas_bi_encoder(self):
        """
        Test calculation of Semantic Answer Similarity (SAS) Score using a Bi-Encoder model.
        """
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        labels = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        evaluation_result = self.create_evaluation_result(predictions, labels)
        sas_result = evaluation_result._calculate_sas(
            output_key="answers", model="sentence-transformers/all-mpnet-base-v2"
        )

        assert sas_result["sas"] == pytest.approx(1.0)
        assert sas_result["scores"] == pytest.approx([1.0, 1.0, 1.0])

    @pytest.mark.integration
    def test_sas_cross_encoder(self):
        """
        Test calculation of Semantic Answer Similarity (SAS) Score using a Cross Encoder model.
        """
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        labels = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        evaluation_result = self.create_evaluation_result(predictions, labels)
        sas_result = evaluation_result._calculate_sas(
            output_key="answers", model="cross-encoder/ms-marco-MiniLM-L-6-v2"
        )

        assert sas_result["sas"] == pytest.approx(0.999967, abs=1e-5)
        assert sas_result["scores"] == pytest.approx(
            [0.9999765157699585, 0.999968409538269, 0.9999572038650513], abs=1e-5
        )