mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-08 13:06:29 +00:00
feat: Add SASEvaluator (#7428)
* Add SASEvaluator
* Add release notes
* Apply suggestions from code review
  Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
* Simplify similarity calculation with bi-encoder models
* Fix linting
* Update docstrings
* Move tensor to CPU after calculating cosine similarity
* Fix CI failing
---------
Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
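For context, here is a minimal usage sketch of the component this commit adds, assembled from the docstring and tests in the diff below. The example strings come from the tests; the default model is downloaded from the Hugging Face hub on warm_up():

```python
from haystack.components.evaluators.sas_evaluator import SASEvaluator

# Defaults to the Bi-Encoder "sentence-transformers/paraphrase-multilingual-mpnet-base-v2".
evaluator = SASEvaluator()
evaluator.warm_up()  # loads the model; run() raises RuntimeError if this step is skipped

result = evaluator.run(
    ground_truths_answers=["US $2.3 billion"],
    predicted_answers=["A construction budget of US $2.3 billion"],
)
print(result["score"])              # mean SAS over all pairs (~0.69 for this pair, per the tests)
print(result["individual_scores"])  # one similarity score per prediction/ground-truth pair
```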
This commit is contained in:
parent 1c7d1618d8
commit 12acb3f12e
187  haystack/components/evaluators/sas_evaluator.py  Normal file
@@ -0,0 +1,187 @@
from typing import Any, Dict, List, Optional

from numpy import mean as np_mean

from haystack import component, default_from_dict, default_to_dict
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, expit
from haystack.utils.auth import Secret, deserialize_secrets_inplace

with LazyImport(message="Run 'pip install scikit-learn \"sentence-transformers>=2.2.0\"'") as sas_import:
    from sentence_transformers import CrossEncoder, SentenceTransformer, util
    from transformers import AutoConfig


@component
class SASEvaluator:
    """
    SASEvaluator computes the Semantic Answer Similarity (SAS) between a list of predictions and a list of
    ground truths. It's usually used in Retrieval Augmented Generation (RAG) pipelines to evaluate the quality
    of the generated answers.

    The SAS is computed using a pre-trained model from the Hugging Face model hub. The model can be either a
    Bi-Encoder or a Cross-Encoder. The choice of the model is based on the `model` parameter.

    Usage example:
    ```python
    from haystack.components.evaluators.sas_evaluator import SASEvaluator

    evaluator = SASEvaluator(model="cross-encoder/ms-marco-MiniLM-L-6-v2")
    evaluator.warm_up()
    ground_truths = [
        "A construction budget of US $2.3 billion",
        "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
        "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
    ]
    predictions = [
        "A construction budget of US $2.3 billion",
        "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
        "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
    ]
    result = evaluator.run(
        ground_truths_answers=ground_truths, predicted_answers=predictions
    )

    print(result["score"])
    # 0.9999673763910929

    print(result["individual_scores"])
    # [0.9999765157699585, 0.999968409538269, 0.9999572038650513]
    ```
    """

    def __init__(
        self,
        model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
        batch_size: int = 32,
        device: Optional[ComponentDevice] = None,
        token: Secret = Secret.from_env_var("HF_API_TOKEN", strict=False),
    ):
        """
        Creates a new instance of SASEvaluator.

        :param model:
            SentenceTransformers semantic textual similarity model, should be a path or a string pointing to
            a downloadable model.
        :param batch_size:
            Number of prediction-label pairs to encode at once.
        :param device:
            The device on which the model is loaded. If `None`, the default device is automatically selected.
        :param token:
            The Hugging Face token for HTTP bearer authorization.
            You can find your HF token in your [account settings](https://huggingface.co/settings/tokens).
        """
        sas_import.check()

        self._model = model
        self._batch_size = batch_size
        self._device = device
        self._token = token
        self._similarity_model = None

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            The serialized component as a dictionary.
        """
        return default_to_dict(
            self,
            model=self._model,
            batch_size=self._batch_size,
            device=self._device.to_dict() if self._device else None,
            token=self._token.to_dict() if self._token else None,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SASEvaluator":
        """
        Deserialize this component from a dictionary.

        :param data:
            The dictionary representation of this component.
        :returns:
            The deserialized component instance.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
        if device := data.get("init_parameters", {}).get("device"):
            data["init_parameters"]["device"] = ComponentDevice.from_dict(device)
        return default_from_dict(cls, data)

    def warm_up(self):
        """
        Initializes the component.
        """
        token = self._token.resolve_value() if self._token else None
        config = AutoConfig.from_pretrained(self._model, use_auth_token=token)
        cross_encoder_used = False
        if config.architectures:
            cross_encoder_used = any(arch.endswith("ForSequenceClassification") for arch in config.architectures)
        device = ComponentDevice.resolve_device(self._device).to_torch_str()
        # Based on the model string we can load either a Bi-Encoder or a Cross-Encoder.
        # The similarity computation differs between the two approaches.
        if cross_encoder_used:
            self._similarity_model = CrossEncoder(
                self._model,
                device=device,
                tokenizer_args={"use_auth_token": token},
                automodel_args={"use_auth_token": token},
            )
        else:
            self._similarity_model = SentenceTransformer(self._model, device=device, use_auth_token=token)

    @component.output_types(score=float, individual_scores=List[float])
    def run(self, ground_truths_answers: List[str], predicted_answers: List[str]) -> Dict[str, Any]:
        """
        Run the SASEvaluator to compute the Semantic Answer Similarity (SAS) between a list of predicted answers
        and a list of ground truth answers. Both must be lists of strings of the same length.

        :param ground_truths_answers:
            A list of expected answers for each question.
        :param predicted_answers:
            A list of generated answers for each question.
        :returns:
            A dictionary with the following outputs:
                - `score`: Mean SAS score over all the prediction/ground-truth pairs.
                - `individual_scores`: A list of similarity scores for each prediction/ground-truth pair.
        """
        if len(ground_truths_answers) != len(predicted_answers):
            raise ValueError("The number of predictions and labels must be the same.")

        if len(predicted_answers) == 0:
            return {"score": 0.0, "individual_scores": [0.0]}

        if not self._similarity_model:
            msg = "The model has not been initialized. Call warm_up() before running the evaluator."
            raise RuntimeError(msg)

        if isinstance(self._similarity_model, CrossEncoder):
            # For Cross-Encoders we create a list of pairs of predictions and labels
            sentence_pairs = list(zip(predicted_answers, ground_truths_answers))
            similarity_scores = self._similarity_model.predict(
                sentence_pairs, batch_size=self._batch_size, convert_to_numpy=True
            )

            # Not all Cross-Encoders return normalized scores: some return raw logits.
            # We normalize the scores if any of them is larger than 1.
            if (similarity_scores > 1).any():
                similarity_scores = expit(similarity_scores)

            # Convert the scores from a numpy array to a list of floats
            similarity_scores = similarity_scores.tolist()

        else:
            # For Bi-Encoders we create embeddings separately for predictions and labels
            predictions_embeddings = self._similarity_model.encode(
                predicted_answers, batch_size=self._batch_size, convert_to_tensor=True
            )
            label_embeddings = self._similarity_model.encode(
                ground_truths_answers, batch_size=self._batch_size, convert_to_tensor=True
            )

            # Compute cosine similarities
            similarity_scores = [
                util.cos_sim(p, l).cpu().numpy() for p, l in zip(predictions_embeddings, label_embeddings)
            ]

        sas_score = np_mean(similarity_scores)

        return {"score": sas_score, "individual_scores": similarity_scores}
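A side note on the Cross-Encoder branch above: `expit`, imported from `haystack.utils`, squashes raw logits into the (0, 1) range. A minimal numpy sketch of that normalization, assuming `expit` behaves like the standard logistic sigmoid (the logits below are made up for illustration):

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    # Logistic function: maps raw logits into the (0, 1) range.
    return 1.0 / (1.0 + np.exp(-x))


logits = np.array([8.2, -1.3, 3.7])  # hypothetical raw cross-encoder outputs
# Same guard as in run() above: only normalize when scores fall outside [0, 1].
scores = sigmoid(logits) if (logits > 1).any() else logits
print(scores.round(4).tolist())  # approx. [0.9997, 0.2142, 0.9759]
```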
4  releasenotes/notes/sas-evaluator-6970865787557e83.yaml  Normal file
@@ -0,0 +1,4 @@
---
features:
  - |
    Add SASEvaluator, which can be used to calculate the Semantic Answer Similarity of answers generated by an LLM.
178  test/components/evaluators/test_sas_evaluator.py  Normal file
@@ -0,0 +1,178 @@
import pytest

from haystack.components.evaluators.sas_evaluator import SASEvaluator
from haystack.utils.device import ComponentDevice


class TestSASEvaluator:
    def test_init_default(self, monkeypatch):
        monkeypatch.setenv("HF_API_TOKEN", "fake-token")
        evaluator = SASEvaluator()

        assert evaluator._model == "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
        assert evaluator._batch_size == 32
        assert evaluator._device is None
        assert evaluator._token.resolve_value() == "fake-token"

    def test_to_dict(self, monkeypatch):
        monkeypatch.setenv("HF_API_TOKEN", "fake-token")

        evaluator = SASEvaluator(device=ComponentDevice.from_str("cuda:0"))

        expected_dict = {
            "type": "haystack.components.evaluators.sas_evaluator.SASEvaluator",
            "init_parameters": {
                "model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
                "batch_size": 32,
                "device": {"type": "single", "device": "cuda:0"},
                "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN"], "strict": False},
            },
        }
        assert evaluator.to_dict() == expected_dict

    def test_from_dict(self, monkeypatch):
        monkeypatch.setenv("HF_API_TOKEN", "fake-token")
        evaluator = SASEvaluator.from_dict(
            {
                "type": "haystack.components.evaluators.sas_evaluator.SASEvaluator",
                "init_parameters": {
                    "model": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
                    "batch_size": 32,
                    "device": {"type": "single", "device": "cuda:0"},
                    "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN"], "strict": False},
                },
            }
        )

        assert evaluator._model == "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
        assert evaluator._batch_size == 32
        assert evaluator._device.to_torch_str() == "cuda:0"
        assert evaluator._token.resolve_value() == "fake-token"

    def test_run_with_empty_inputs(self):
        evaluator = SASEvaluator()
        result = evaluator.run(ground_truths_answers=[], predicted_answers=[])
        assert len(result) == 2
        assert result["score"] == 0.0
        assert result["individual_scores"] == [0.0]

    def test_run_with_different_lengths(self):
        evaluator = SASEvaluator()
        ground_truths = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
        ]
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        with pytest.raises(ValueError):
            evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)

    def test_run_not_warmed_up(self):
        evaluator = SASEvaluator()
        ground_truths = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        with pytest.raises(RuntimeError):
            evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)

    @pytest.mark.integration
    def test_run_with_matching_predictions(self):
        evaluator = SASEvaluator()
        ground_truths = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        evaluator.warm_up()
        result = evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)

        assert len(result) == 2
        assert result["score"] == pytest.approx(1.0)
        assert result["individual_scores"] == pytest.approx([1.0, 1.0, 1.0])

    @pytest.mark.integration
    def test_run_with_single_prediction(self):
        evaluator = SASEvaluator()

        ground_truths = ["US $2.3 billion"]
        evaluator.warm_up()
        result = evaluator.run(
            ground_truths_answers=ground_truths, predicted_answers=["A construction budget of US $2.3 billion"]
        )
        assert len(result) == 2
        assert result["score"] == pytest.approx(0.689089, abs=1e-5)
        assert result["individual_scores"] == pytest.approx([0.689089], abs=1e-5)

    @pytest.mark.integration
    def test_run_with_mismatched_predictions(self):
        evaluator = SASEvaluator()
        ground_truths = [
            "US $2.3 billion",
            "Paris's cultural magnificence is symbolized by the Eiffel Tower",
            "Japan was transformed into a modernized world power after the Meiji Restoration.",
        ]
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        evaluator.warm_up()
        result = evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
        assert len(result) == 2
        assert result["score"] == pytest.approx(0.8227189)
        assert result["individual_scores"] == pytest.approx([0.689089, 0.870389, 0.908679], abs=1e-5)

    @pytest.mark.integration
    def test_run_with_bi_encoder_model(self):
        evaluator = SASEvaluator(model="sentence-transformers/all-mpnet-base-v2")
        ground_truths = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        evaluator.warm_up()
        result = evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
        assert len(result) == 2
        assert result["score"] == pytest.approx(1.0)
        assert result["individual_scores"] == pytest.approx([1.0, 1.0, 1.0])

    @pytest.mark.integration
    def test_run_with_cross_encoder_model(self):
        evaluator = SASEvaluator(model="cross-encoder/ms-marco-MiniLM-L-6-v2")
        ground_truths = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        predictions = [
            "A construction budget of US $2.3 billion",
            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
        ]
        evaluator.warm_up()
        result = evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
        assert len(result) == 2
        assert result["score"] == pytest.approx(0.999967, abs=1e-5)
        assert result["individual_scores"] == pytest.approx(
            [0.9999765157699585, 0.999968409538269, 0.9999572038650513], abs=1e-5
        )