From 12acb3f12e2ab92ff61534f55c2cb65dc64e382b Mon Sep 17 00:00:00 2001
From: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Date: Thu, 4 Apr 2024 10:10:41 +0200
Subject: [PATCH] feat: Add `SASEvaluator` (#7428)

* Add SASEvaluator

* Add release notes

* Apply suggestions from code review

Co-authored-by: Madeesh Kannan

* Simplify similarity calculation with bi-encoder models

* Fix linting

* Update docstrings

* Move tensor to CPU after calculating cosine similarity

* Fix CI failing

---------

Co-authored-by: Madeesh Kannan
---
 .../components/evaluators/sas_evaluator.py    | 187 ++++++++++++++++++
 .../notes/sas-evaluator-6970865787557e83.yaml |   4 +
 .../evaluators/test_sas_evaluator.py          | 178 +++++++++++++++++
 3 files changed, 369 insertions(+)
 create mode 100644 haystack/components/evaluators/sas_evaluator.py
 create mode 100644 releasenotes/notes/sas-evaluator-6970865787557e83.yaml
 create mode 100644 test/components/evaluators/test_sas_evaluator.py

diff --git a/haystack/components/evaluators/sas_evaluator.py b/haystack/components/evaluators/sas_evaluator.py
new file mode 100644
index 000000000..3f16f3a9a
--- /dev/null
+++ b/haystack/components/evaluators/sas_evaluator.py
@@ -0,0 +1,187 @@
+from typing import Any, Dict, List, Optional
+
+from numpy import mean as np_mean
+
+from haystack import component, default_from_dict, default_to_dict
+from haystack.lazy_imports import LazyImport
+from haystack.utils import ComponentDevice, expit
+from haystack.utils.auth import Secret, deserialize_secrets_inplace
+
+with LazyImport(message="Run 'pip install scikit-learn \"sentence-transformers>=2.2.0\"'") as sas_import:
+    from sentence_transformers import CrossEncoder, SentenceTransformer, util
+    from transformers import AutoConfig
+
+
+@component
+class SASEvaluator:
+    """
+    SASEvaluator computes the Semantic Answer Similarity (SAS) between a list of predictions and a list of
+    ground truths. It's usually used in Retrieval Augmented Generation (RAG) pipelines to evaluate the
+    quality of the generated answers.
+
+    The SAS is computed using a pre-trained model from the Hugging Face model hub. The model, selected via
+    the `model` parameter, can be either a Bi-Encoder or a Cross-Encoder.
+
+    Usage example:
+    ```python
+    from haystack.components.evaluators.sas_evaluator import SASEvaluator
+
+    evaluator = SASEvaluator(model="cross-encoder/ms-marco-MiniLM-L-6-v2")
+    evaluator.warm_up()
+    ground_truths = [
+        "A construction budget of US $2.3 billion",
+        "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+        "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+    ]
+    predictions = [
+        "A construction budget of US $2.3 billion",
+        "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+        "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+    ]
+    result = evaluator.run(
+        ground_truths_answers=ground_truths, predicted_answers=predictions
+    )
+
+    print(result["score"])
+    # 0.9999673763910929
+
+    print(result["individual_scores"])
+    # [0.9999765157699585, 0.999968409538269, 0.9999572038650513]
+    ```
+    """
+
+    def __init__(
+        self,
+        model: str = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
+        batch_size: int = 32,
+        device: Optional[ComponentDevice] = None,
+        token: Secret = Secret.from_env_var("HF_API_TOKEN", strict=False),
+    ):
+        """
+        Creates a new instance of SASEvaluator.
+
+        :param model:
+            SentenceTransformers semantic textual similarity model; should be a path or a string pointing
+            to a downloadable model.
+        :param batch_size:
+            Number of prediction-label pairs to encode at once.
+        :param device:
+            The device on which the model is loaded. If `None`, the default device is automatically selected.
+        :param token:
+            The Hugging Face token for HTTP bearer authorization.
+            You can find your HF token in your [account settings](https://huggingface.co/settings/tokens).
+        """
+        sas_import.check()
+
+        self._model = model
+        self._batch_size = batch_size
+        self._device = device
+        self._token = token
+        self._similarity_model = None
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serialize this component to a dictionary.
+
+        :returns:
+            The serialized component as a dictionary.
+        """
+        return default_to_dict(
+            self,
+            model=self._model,
+            batch_size=self._batch_size,
+            device=self._device.to_dict() if self._device else None,
+            token=self._token.to_dict() if self._token else None,
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "SASEvaluator":
+        """
+        Deserialize this component from a dictionary.
+
+        :param data:
+            The dictionary representation of this component.
+        :returns:
+            The deserialized component instance.
+        """
+        deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
+        if device := data.get("init_parameters", {}).get("device"):
+            data["init_parameters"]["device"] = ComponentDevice.from_dict(device)
+        return default_from_dict(cls, data)
+
+    def warm_up(self):
+        """
+        Initializes the component by loading the similarity model.
+        """
+        token = self._token.resolve_value() if self._token else None
+        config = AutoConfig.from_pretrained(self._model, use_auth_token=token)
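+        # A checkpoint whose architecture ends in "ForSequenceClassification" carries a
+        # classification head that scores a sentence pair directly, i.e. a Cross-Encoder;
+        # any other architecture is treated as a Bi-Encoder that embeds each sentence
+        # separately.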
+        cross_encoder_used = False
+        if config.architectures:
+            cross_encoder_used = any(arch.endswith("ForSequenceClassification") for arch in config.architectures)
+        device = ComponentDevice.resolve_device(self._device).to_torch_str()
+        # Based on the model string, we load either a Bi-Encoder or a Cross-Encoder.
+        # The similarity computation differs between the two approaches.
+        if cross_encoder_used:
+            self._similarity_model = CrossEncoder(
+                self._model,
+                device=device,
+                tokenizer_args={"use_auth_token": token},
+                automodel_args={"use_auth_token": token},
+            )
+        else:
+            self._similarity_model = SentenceTransformer(self._model, device=device, use_auth_token=token)
+
+    @component.output_types(score=float, individual_scores=List[float])
+    def run(self, ground_truths_answers: List[str], predicted_answers: List[str]) -> Dict[str, Any]:
+        """
+        Run the SASEvaluator to compute the Semantic Answer Similarity (SAS) between a list of predicted
+        answers and a list of ground truth answers. Both must be lists of strings of the same length.
+
+        :param ground_truths_answers:
+            A list of expected answers for each question.
+        :param predicted_answers:
+            A list of generated answers for each question.
+        :returns:
+            A dictionary with the following outputs:
+                - `score`: Mean SAS score over all the predictions/ground-truth pairs.
+                - `individual_scores`: A list of similarity scores for each prediction/ground-truth pair.
+        """
+        if len(ground_truths_answers) != len(predicted_answers):
+            raise ValueError("The number of predictions and labels must be the same.")
+
+        if len(predicted_answers) == 0:
+            return {"score": 0.0, "individual_scores": [0.0]}
+
+        if not self._similarity_model:
+            msg = "The model has not been initialized. Call warm_up() before running the evaluator."
+            raise RuntimeError(msg)
+
+        if isinstance(self._similarity_model, CrossEncoder):
+            # For Cross-Encoders we create a list of pairs of predictions and labels
+            sentence_pairs = list(zip(predicted_answers, ground_truths_answers))
+            similarity_scores = self._similarity_model.predict(
+                sentence_pairs, batch_size=self._batch_size, convert_to_numpy=True
+            )
+
+            # Not all Cross-Encoders return normalized scores: some return raw logits.
+            # We normalize the scores with the logistic sigmoid if any of them is larger than 1.
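+            # For intuition: expit(x) = 1 / (1 + exp(-x)) maps raw logits into (0, 1),
+            # e.g. expit(0.0) == 0.5 and expit(4.0) ≈ 0.982.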
+            if (similarity_scores > 1).any():
+                similarity_scores = expit(similarity_scores)
+
+            # Convert scores to a list of floats from a numpy array
+            similarity_scores = similarity_scores.tolist()
+
+        else:
+            # For Bi-Encoders we create embeddings separately for predictions and labels
+            predictions_embeddings = self._similarity_model.encode(
+                predicted_answers, batch_size=self._batch_size, convert_to_tensor=True
+            )
+            label_embeddings = self._similarity_model.encode(
+                ground_truths_answers, batch_size=self._batch_size, convert_to_tensor=True
+            )
+
+            # Compute cosine similarities
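+            # util.cos_sim on two 1-D tensors returns a 1x1 tensor, so we move it to the
+            # CPU and extract the scalar, keeping `individual_scores` a plain list of floats.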
"sentence-transformers/paraphrase-multilingual-mpnet-base-v2" + assert evaluator._batch_size == 32 + assert evaluator._device.to_torch_str() == "cuda:0" + assert evaluator._token.resolve_value() == "fake-token" + + def test_run_with_empty_inputs(self): + evaluator = SASEvaluator() + result = evaluator.run(ground_truths_answers=[], predicted_answers=[]) + assert len(result) == 2 + assert result["score"] == 0.0 + assert result["individual_scores"] == [0.0] + + def test_run_with_different_lengths(self): + evaluator = SASEvaluator() + ground_truths = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + ] + predictions = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + with pytest.raises(ValueError): + evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions) + + def test_run_not_warmed_up(self): + evaluator = SASEvaluator() + ground_truths = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + predictions = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + with pytest.raises(RuntimeError): + evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions) + + @pytest.mark.integration + def test_run_with_matching_predictions(self): + evaluator = SASEvaluator() + ground_truths = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + predictions = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + evaluator.warm_up() + result = evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions) + + assert len(result) == 2 + assert result["score"] == pytest.approx(1.0) + assert result["individual_scores"] == pytest.approx([1.0, 1.0, 1.0]) + + @pytest.mark.integration + def test_run_with_single_prediction(self): + evaluator = SASEvaluator() + + ground_truths = ["US $2.3 billion"] + evaluator.warm_up() + result = evaluator.run( + ground_truths_answers=ground_truths, predicted_answers=["A construction budget of US $2.3 billion"] + ) + assert len(result) == 2 + assert result["score"] == pytest.approx(0.689089, abs=1e-5) + assert result["individual_scores"] == pytest.approx([0.689089], abs=1e-5) + + @pytest.mark.integration + def test_run_with_mismatched_predictions(self): + evaluator = SASEvaluator() + ground_truths = [ + "US $2.3 billion", + "Paris's cultural magnificence is symbolized by the Eiffel Tower", + "Japan was transformed into a modernized world power after the Meiji Restoration.", + ] + predictions = [ + "A construction budget of US $2.3 billion", + "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.", + "The Meiji Restoration in 1868 transformed Japan into a modernized world power.", + ] + 
+    @pytest.mark.integration
+    def test_run_with_matching_predictions(self):
+        evaluator = SASEvaluator()
+        ground_truths = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+        ]
+        predictions = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+        ]
+        evaluator.warm_up()
+        result = evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
+
+        assert len(result) == 2
+        assert result["score"] == pytest.approx(1.0)
+        assert result["individual_scores"] == pytest.approx([1.0, 1.0, 1.0])
+
+    @pytest.mark.integration
+    def test_run_with_single_prediction(self):
+        evaluator = SASEvaluator()
+
+        ground_truths = ["US $2.3 billion"]
+        evaluator.warm_up()
+        result = evaluator.run(
+            ground_truths_answers=ground_truths, predicted_answers=["A construction budget of US $2.3 billion"]
+        )
+        assert len(result) == 2
+        assert result["score"] == pytest.approx(0.689089, abs=1e-5)
+        assert result["individual_scores"] == pytest.approx([0.689089], abs=1e-5)
+
+    @pytest.mark.integration
+    def test_run_with_mismatched_predictions(self):
+        evaluator = SASEvaluator()
+        ground_truths = [
+            "US $2.3 billion",
+            "Paris's cultural magnificence is symbolized by the Eiffel Tower",
+            "Japan was transformed into a modernized world power after the Meiji Restoration.",
+        ]
+        predictions = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+        ]
+        evaluator.warm_up()
+        result = evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
+        assert len(result) == 2
+        assert result["score"] == pytest.approx(0.8227189)
+        assert result["individual_scores"] == pytest.approx([0.689089, 0.870389, 0.908679], abs=1e-5)
+
+    @pytest.mark.integration
+    def test_run_with_bi_encoder_model(self):
+        evaluator = SASEvaluator(model="sentence-transformers/all-mpnet-base-v2")
+        ground_truths = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+        ]
+        predictions = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+        ]
+        evaluator.warm_up()
+        result = evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
+        assert len(result) == 2
+        assert result["score"] == pytest.approx(1.0)
+        assert result["individual_scores"] == pytest.approx([1.0, 1.0, 1.0])
+
+    @pytest.mark.integration
+    def test_run_with_cross_encoder_model(self):
+        evaluator = SASEvaluator(model="cross-encoder/ms-marco-MiniLM-L-6-v2")
+        ground_truths = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+        ]
+        predictions = [
+            "A construction budget of US $2.3 billion",
+            "The Eiffel Tower, completed in 1889, symbolizes Paris's cultural magnificence.",
+            "The Meiji Restoration in 1868 transformed Japan into a modernized world power.",
+        ]
+        evaluator.warm_up()
+        result = evaluator.run(ground_truths_answers=ground_truths, predicted_answers=predictions)
+        assert len(result) == 2
+        assert result["score"] == pytest.approx(0.999967, abs=1e-5)
+        assert result["individual_scores"] == pytest.approx(
+            [0.9999765157699585, 0.999968409538269, 0.9999572038650513], abs=1e-5
+        )
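
A minimal round-trip sketch of the API added in this patch, tying together `to_dict`, `from_dict`, `warm_up`, and `run` (not part of the diff; it assumes the default model can be downloaded, and the printed scores depend on the model):

```python
from haystack.components.evaluators.sas_evaluator import SASEvaluator

# Serialize an evaluator and restore it; the round trip preserves the init
# parameters, as exercised by test_to_dict and test_from_dict.
evaluator = SASEvaluator()
restored = SASEvaluator.from_dict(evaluator.to_dict())

# The restored evaluator still needs warm_up() before it can run.
restored.warm_up()
result = restored.run(
    ground_truths_answers=["US $2.3 billion"],
    predicted_answers=["A construction budget of US $2.3 billion"],
)
print(result["score"], result["individual_scores"])
```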