feat: Add SASEvaluator (#7428)

* Add SASEvaluator

* Add release notes

* Apply suggestions from code review

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* Simplify similarity calculation with bi-encoder models

* Fix linting

* Update docstrings

* Move tensor to CPU after calculating cosine similarity

* Fix failing CI

---------

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
Silvano Cerza 2024-04-04 10:10:41 +02:00 committed by GitHub
parent 1c7d1618d8
commit 12acb3f12e
3 changed files with 369 additions and 0 deletions


@@ -0,0 +1,187 @@
from typing import Any, Dict, List, Optional
from numpy import mean as np_mean
from haystack import component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, expit
from haystack.utils.auth import Secret, deserialize_secrets_inplace
with LazyImport(message="Run 'pip install scikit-learn \"sentence-transformers>=2.2.0\"'") as sas_import:
from sentence_transformers import CrossEncoder, SentenceTransformer, util
from transformers import AutoConfig
@component
class SASEvaluator:
"""
SASEvaluator computes the Semantic Answer Similarity (SAS) between a list of predictions and a list of ground truths.
It's usually used in Retrieval Augmented Generation (RAG) pipelines to evaluate the quality of the generated answers.
The SAS is computed using a pre-trained model from the Hugging Face model hub. The model can be either a
Bi-Encoder or a Cross-Encoder; the type is detected automatically from the architecture of the model passed in the `model` parameter.
Usage example:
```python
from haystack.components.evaluators.sas_evaluator import SASEvaluator
evaluator = SASEvaluator(model="cross-encoder/ms-marco-MiniLM-L-6-v2")
evaluator.warm_up()
ground_truths = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
result = evaluator.run(
ground_truths_answers=ground_truths, predicted_answers=predictions
)
print(result["score"])
# 0.9999673763910929
print(result["individual_scores"])
# [0.9999765157699585, 0.999968409538269, 0.9999572038650513]
```
"""
def __init__(
self,
model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
batch_size: int = 32,
device: Optional[ComponentDevice] = None,
token: Secret = Secret.from_env_var("HF_API_TOKEN", strict=False),
):
"""
Creates a new instance of SASEvaluator.
:param model:
SentenceTransformers semantic textual similarity model; should be a path or a string pointing to a downloadable model.
:param batch_size:
Number of prediction-label pairs to encode at once.
:param device:
The device on which the model is loaded. If `None`, the default device is automatically selected.
:param token:
The Hugging Face token for HTTP bearer authorization.
You can find your HF token in your [account settings](https://huggingface.co/settings/tokens).
"""
sas_import.check()
self._model = model
self._batch_size = batch_size
self._device = device
self._token = token
self._similarity_model = None
def to_dict(self) -> Dict[str, Any]:
"""
Serialize this component to a dictionary.
:returns:
The serialized component as a dictionary.
"""
return default_to_dict(
self,
model=self._model,
batch_size=self._batch_size,
device=self._device.to_dict() if self._device else None,
token=self._token.to_dict() if self._token else None,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SASEvaluator":
"""
Deserialize this component from a dictionary.
:param data:
The dictionary representation of this component.
:returns:
The deserialized component instance.
"""
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
if device := data.get("init_parameters", {}).get("device"):
data["init_parameters"]["device"] = ComponentDevice.from_dict(device)
return default_from_dict(cls, data)
def warm_up(self):
"""
Initializes the component.
"""
token = self._token.resolve_value() if self._token else None
config = AutoConfig.from_pretrained(self._model, use_auth_token=token)
cross_encoder_used = False
if config.architectures:
cross_encoder_used = any(arch.endswith("ForSequenceClassification") for arch in config.architectures)
device = ComponentDevice.resolve_device(self._device).to_torch_str()
# Based on the model's architecture we load either a Bi-Encoder or a Cross-Encoder.
# The similarity computation differs between the two approaches.
if cross_encoder_used:
self._similarity_model = CrossEncoder(
self._model,
device=device,
tokenizer_args={"use_auth_token": token},
automodel_args={"use_auth_token": token},
)
else:
self._similarity_model = SentenceTransformer(self._model, device=device, use_auth_token=token)
@component.output_types(score=float, individual_scores=List[float])
def run(self, ground_truths_answers: List[str], predicted_answers: List[str]) -> Dict[str, Any]:
"""
Run the SASEvaluator to compute the Semantic Answer Similarity (SAS) between a list of predicted answers
and a list of ground truth answers. Both must be lists of strings of the same length.
:param ground_truths_answers:
A list of expected answers for each question.
:param predicted_answers:
A list of generated answers for each question.
:returns:
A dictionary with the following outputs:
- `score`: Mean SAS score over all the predictions/ground-truth pairs.
- `individual_scores`: A list of similarity scores for each prediction/ground-truth pair.
"""
if len(ground_truths_answers) != len(predicted_answers):
raise ValueError("The number of predictions and labels must be the same.")
if len(predicted_answers) == 0:
return {"score": 0.0, "individual_scores": [0.0]}
if not self._similarity_model:
msg = "The model has not been initialized. Call warm_up() before running the evaluator."
raise RuntimeError(msg)
if isinstance(self._similarity_model, CrossEncoder):
# For Cross Encoders we create a list of pairs of predictions and labels
sentence_pairs = list(zip(predicted_answers, ground_truths_answers))
similarity_scores = self._similarity_model.predict(
sentence_pairs, batch_size=self._batch_size, convert_to_numpy=True
)
# Cross-Encoders may return raw logit scores that are not normalized.
# We normalize with the sigmoid function (expit) if any score is larger than 1.
if (similarity_scores > 1).any():
similarity_scores = expit(similarity_scores)
# Convert scores to list of floats from numpy array
similarity_scores = similarity_scores.tolist()
else:
# For Bi-encoders we create embeddings separately for predictions and labels
predictions_embeddings = self._similarity_model.encode(
predicted_answers, batch_size=self._batch_size, convert_to_tensor=True
)
label_embeddings = self._similarity_model.encode(
ground_truths_answers, batch_size=self._batch_size, convert_to_tensor=True
)
# Compute cosine-similarities
similarity_scores = [
util.cos_sim(p, l).cpu().numpy() for p, l in zip(predictions_embeddings, label_embeddings)
]
sas_score = np_mean(similarity_scores)
return {"score": sas_score, "individual_scores": similarity_scores}
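For reference, a minimal sketch of the serialization round-trip defined above, assuming `sentence-transformers` is installed and an `HF_API_TOKEN` environment variable is available; the dictionary layout mirrors the unit tests further below:

```python
import os

from haystack.components.evaluators.sas_evaluator import SASEvaluator
from haystack.utils import ComponentDevice

# Assumed for illustration: the token is resolved lazily from this environment variable.
os.environ["HF_API_TOKEN"] = "fake-token"

evaluator = SASEvaluator(device=ComponentDevice.from_str("cuda:0"))

data = evaluator.to_dict()
# data["init_parameters"]["device"] == {"type": "single", "device": "cuda:0"}
# data["init_parameters"]["token"] == {"type": "env_var", "env_vars": ["HF_API_TOKEN"], "strict": False}

restored = SASEvaluator.from_dict(data)
assert restored._device.to_torch_str() == "cuda:0"
assert restored._token.resolve_value() == "fake-token"
```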


@@ -0,0 +1,4 @@
---
features:
- |
Add SASEvaluator, which can be used to calculate the Semantic Answer Similarity of answers generated by an LLM
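A minimal usage sketch for the feature described above, using the default bi-encoder model; `warm_up()` must be called before `run()`, and the first call downloads the model from the Hugging Face Hub (the expected score comes from the single-prediction integration test below):

```python
from haystack.components.evaluators.sas_evaluator import SASEvaluator

# Default model: sentence-transformers/paraphrase-multilingual-mpnet-base-v2 (a bi-encoder).
evaluator = SASEvaluator()
evaluator.warm_up()  # loads the model; run() raises a RuntimeError otherwise

result = evaluator.run(
    ground_truths_answers=["US $2.3 billion"],
    predicted_answers=["A construction budget of US $2.3 billion"],
)
print(result["score"])              # ~0.689 according to the integration test below
print(result["individual_scores"])  # one similarity score per prediction/ground-truth pair
```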


@@ -0,0 +1,178 @@
import pytest
from haystack.components.evaluators.sas_evaluator import SASEvaluator
from haystack.utils.device import ComponentDevice
class TestSASEvaluator:
def test_init_default(self, monkeypatch):
monkeypatch.setenv("HF_API_TOKEN", "fake-token")
evaluator = SASEvaluator()
assert evaluator._model == "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
assert evaluator._batch_size == 32
assert evaluator._device is None
assert evaluator._token.resolve_value() == "fake-token"
def test_to_dict(self, monkeypatch):
monkeypatch.setenv("HF_API_TOKEN", "fake-token")
evaluator = SASEvaluator(device=ComponentDevice.from_str("cuda:0"))
expected_dict = {
"type": "haystack.components.evaluators.sas_evaluator.SASEvaluator",
"init_parameters": {
"model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
"batch_size": 32,
"device": {"type": "single", "device": "cuda:0"},
"token": {"type": "env_var", "env_vars": ["HF_API_TOKEN"], "strict": False},
},
}
assert evaluator.to_dict() == expected_dict
def test_from_dict(self, monkeypatch):
monkeypatch.setenv("HF_API_TOKEN", "fake-token")
evaluator = SASEvaluator.from_dict(
{
"type": "haystack.components.evaluators.sas_evaluator.SASEvaluator",
"init_parameters": {
"model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
"batch_size": 32,
"device": {"type": "single", "device": "cuda:0"},
"token": {"type": "env_var", "env_vars": ["HF_API_TOKEN"], "strict": False},
},
}
)
assert evaluator._model == "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
assert evaluator._batch_size == 32
assert evaluator._device.to_torch_str() == "cuda:0"
assert evaluator._token.resolve_value() == "fake-token"
def test_run_with_empty_inputs(self):
evaluator = SASEvaluator()
result = evaluator.run(ground_truths_answers=[], predicted_answers=[])
assert len(result) == 2
assert result["score"] == 0.0
assert result["individual_scores"] == [0.0]
def test_run_with_different_lengths(self):
evaluator = SASEvaluator()
ground_truths = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
]
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
with pytest.raises(ValueError):
evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
def test_run_not_warmed_up(self):
evaluator = SASEvaluator()
ground_truths = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
with pytest.raises(RuntimeError):
evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
@pytest.mark.integration
def test_run_with_matching_predictions(self):
evaluator = SASEvaluator()
ground_truths = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
evaluator.warm_up()
result = evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
assert len(result) == 2
assert result["score"] == pytest.approx(1.0)
assert result["individual_scores"] == pytest.approx([1.0, 1.0, 1.0])
@pytest.mark.integration
def test_run_with_single_prediction(self):
evaluator = SASEvaluator()
ground_truths = ["US $2.3 billion"]
evaluator.warm_up()
result = evaluator.run(
ground_truths_answers=ground_truths, predicted_answers=["A construction budget of US $2.3 billion"]
)
assert len(result) == 2
assert result["score"] == pytest.approx(0.689089, abs=1e-5)
assert result["individual_scores"] == pytest.approx([0.689089], abs=1e-5)
@pytest.mark.integration
def test_run_with_mismatched_predictions(self):
evaluator = SASEvaluator()
ground_truths = [
"US $2.3 billion",
"Paris's cultural magnificence is symbolized by the Eiffel Tower",
"Japan was transformed into a modernized world power after the Meiji Restoration.",
]
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
evaluator.warm_up()
result = evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
assert len(result) == 2
assert result["score"] == pytest.approx(0.8227189)
assert result["individual_scores"] == pytest.approx([0.689089, 0.870389, 0.908679], abs=1e-5)
@pytest.mark.integration
def test_run_with_bi_encoder_model(self):
evaluator = SASEvaluator(model="sentence-transformers/all-mpnet-base-v2")
ground_truths = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
evaluator.warm_up()
result = evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
assert len(result) == 2
assert result["score"] == pytest.approx(1.0)
assert result["individual_scores"] == pytest.approx([1.0, 1.0, 1.0])
@pytest.mark.integration
def test_run_with_cross_encoder_model(self):
evaluator = SASEvaluator(model="cross-encoder/ms-marco-MiniLM-L-6-v2")
ground_truths = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
predictions = [
"A construction budget of US $2.3 billion",
"The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
"The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
]
evaluator.warm_up()
result = evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
assert len(result) == 2
assert result["score"] == pytest.approx(0.999967, abs=1e-5)
assert result["individual_scores"] == pytest.approx(
[0.9999765157699585, 0.999968409538269, 0.9999572038650513], abs=1e-5
)