diff --git a/haystack/components/rankers/transformers_similarity.py b/haystack/components/rankers/transformers_similarity.py
index d22b043b5..6dc163f5a 100644
--- a/haystack/components/rankers/transformers_similarity.py
+++ b/haystack/components/rankers/transformers_similarity.py
@@ -43,6 +43,9 @@ class TransformersSimilarityRanker:
         top_k: int = 10,
         meta_fields_to_embed: Optional[List[str]] = None,
         embedding_separator: str = "\n",
+        scale_score: bool = True,
+        calibration_factor: Optional[float] = 1.0,
+        score_threshold: Optional[float] = None,
         model_kwargs: Optional[Dict[str, Any]] = None,
     ):
         """
@@ -57,6 +60,11 @@ class TransformersSimilarityRanker:
         :param top_k: The maximum number of Documents to return per query.
         :param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
         :param embedding_separator: Separator used to concatenate the meta fields to the Document content.
+        :param scale_score: Whether the raw logit predictions will be scaled using a sigmoid activation function.
+            Set this to False if you do not want any scaling of the raw logit predictions.
+        :param calibration_factor: Factor used for calibrating probabilities calculated by
+            `sigmoid(logits * calibration_factor)`. This is only used if `scale_score` is set to True.
+        :param score_threshold: If provided, only returns documents with a score above this threshold.
         :param model_kwargs: Additional keyword arguments passed to `AutoModelForSequenceClassification.from_pretrained`
             when loading the model specified in `model_name_or_path`. For details on what kwargs you can pass, see
             the model's documentation.
@@ -73,6 +81,13 @@ class TransformersSimilarityRanker:
         self.tokenizer = None
         self.meta_fields_to_embed = meta_fields_to_embed or []
         self.embedding_separator = embedding_separator
+        self.scale_score = scale_score
+        self.calibration_factor = calibration_factor
+        if self.scale_score and self.calibration_factor is None:
+            raise ValueError(
+                f"scale_score is True so calibration_factor must be provided, but got {calibration_factor}"
+            )
+        self.score_threshold = score_threshold
         self.model_kwargs = model_kwargs or {}

     def _get_telemetry_data(self) -> Dict[str, Any]:
@@ -106,28 +121,53 @@ class TransformersSimilarityRanker:
             top_k=self.top_k,
             meta_fields_to_embed=self.meta_fields_to_embed,
             embedding_separator=self.embedding_separator,
+            scale_score=self.scale_score,
+            calibration_factor=self.calibration_factor,
+            score_threshold=self.score_threshold,
             model_kwargs=self.model_kwargs,
         )

     @component.output_types(documents=List[Document])
-    def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
+    def run(
+        self,
+        query: str,
+        documents: List[Document],
+        top_k: Optional[int] = None,
+        scale_score: Optional[bool] = None,
+        calibration_factor: Optional[float] = None,
+        score_threshold: Optional[float] = None,
+    ):
         """
         Returns a list of Documents ranked by their similarity to the given query.

         :param query: Query string.
         :param documents: List of Documents.
         :param top_k: The maximum number of Documents you want the Ranker to return.
+        :param scale_score: Whether the raw logit predictions will be scaled using a sigmoid activation function.
+            Set this to False if you do not want any scaling of the raw logit predictions.
+        :param calibration_factor: Factor used for calibrating probabilities calculated by
+            `sigmoid(logits * calibration_factor)`. This is only used if `scale_score` is set to True.
+        :param score_threshold: If provided, only returns documents with a score above this threshold.
         :return: List of Documents sorted by their similarity to the query with the most similar Documents appearing first.
         """
         if not documents:
             return {"documents": []}

-        if top_k is None:
-            top_k = self.top_k
-
-        elif top_k <= 0:
+        top_k = top_k if top_k is not None else self.top_k
+        if top_k <= 0:
             raise ValueError(f"top_k must be > 0, but got {top_k}")

+        scale_score = scale_score if scale_score is not None else self.scale_score
+        calibration_factor = calibration_factor if calibration_factor is not None else self.calibration_factor
+
+        if scale_score and calibration_factor is None:
+            raise ValueError(
+                f"scale_score is True so calibration_factor must be provided, but got {calibration_factor}"
+            )
+
+        if score_threshold is None:
+            score_threshold = self.score_threshold
+
         # If a model path is provided but the model isn't loaded
         if self.model_name_or_path and not self.model:
             raise ComponentError(
@@ -150,10 +190,20 @@ class TransformersSimilarityRanker:
         with torch.inference_mode():
             similarity_scores = self.model(**features).logits.squeeze(dim=1)  # type: ignore
+            if scale_score:
+                similarity_scores = torch.sigmoid(similarity_scores * calibration_factor)
             _, sorted_indices = torch.sort(similarity_scores, descending=True)
+
+        sorted_indices = sorted_indices.cpu().tolist()  # type: ignore
+        similarity_scores = similarity_scores.cpu().tolist()

         ranked_docs = []
-        for sorted_index_tensor in sorted_indices:
-            i = sorted_index_tensor.item()
-            documents[i].score = similarity_scores[i].item()
+        for sorted_index in sorted_indices:
+            i = sorted_index
+            documents[i].score = similarity_scores[i]
             ranked_docs.append(documents[i])
+
+        if score_threshold is not None:
+            ranked_docs = [doc for doc in ranked_docs if doc.score >= score_threshold]
+
         return {"documents": ranked_docs[:top_k]}
diff --git a/releasenotes/notes/scale-score-similarity-ranker-2deacff999265b9e.yaml b/releasenotes/notes/scale-score-similarity-ranker-2deacff999265b9e.yaml
new file mode 100644
index 000000000..8b7f26d29
--- /dev/null
+++ b/releasenotes/notes/scale-score-similarity-ranker-2deacff999265b9e.yaml
@@ -0,0 +1,6 @@
+---
+enhancements:
+  - |
+    Adds scale_score, which lets users choose whether document scores are returned as raw logits or scaled between 0 and 1 (using the sigmoid function). This feature already existed in Haystack v1 and is being ported over.
+    Adds calibration_factor. Following the example of the ExtractiveReader, this lets the user control the spread of scores when scaling them with the sigmoid function.
+    Adds score_threshold, also adopted from the ExtractiveReader. When set, only documents with a score above this threshold are returned.
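Reviewer note: a minimal usage sketch of the new parameters. This assumes the component API shown in this diff; warm_up() (which loads the model and tokenizer before run() may be called) is not part of this patch and is assumed to exist, and the model name is only an example.

from haystack import Document
from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker

# Values set here apply to every run() call unless overridden per call.
ranker = TransformersSimilarityRanker(
    model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2",
    scale_score=True,        # sigmoid-scale raw logits into (0, 1)
    calibration_factor=1.0,  # scores become sigmoid(logits * 1.0)
)
ranker.warm_up()  # assumed: loads the model before run() can be called

docs = [Document(content="Berlin"), Document(content="Belgrade"), Document(content="Sarajevo")]
result = ranker.run(query="City in Bosnia and Herzegovina", documents=docs)

# Per-call overrides take precedence over the __init__ settings; here only
# documents with a sigmoid-scaled score above 0.005 are returned.
result = ranker.run(
    query="City in Bosnia and Herzegovina",
    documents=docs,
    score_threshold=0.005,
)

With the sigmoid-scaled scores recorded in the integration test below (Sarajevo about 0.0099, Belgrade about 0.00012, Berlin about 0.000023), only the Sarajevo document would clear the 0.005 threshold.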
diff --git a/test/components/rankers/test_transformers_similarity.py b/test/components/rankers/test_transformers_similarity.py
index 06bdf59f6..21566c464 100644
--- a/test/components/rankers/test_transformers_similarity.py
+++ b/test/components/rankers/test_transformers_similarity.py
@@ -1,6 +1,7 @@
 from unittest.mock import MagicMock, patch
 import pytest
 import torch
+from transformers.modeling_outputs import SequenceClassifierOutput

 from haystack import Document, ComponentError
 from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker
@@ -19,6 +20,9 @@ class TestSimilarityRanker:
                 "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
                 "meta_fields_to_embed": [],
                 "embedding_separator": "\n",
+                "scale_score": True,
+                "calibration_factor": 1.0,
+                "score_threshold": None,
                 "model_kwargs": {},
             },
         }
@@ -29,6 +33,9 @@ class TestSimilarityRanker:
            device="cuda",
            token="my_token",
            top_k=5,
+           scale_score=False,
+           calibration_factor=None,
+           score_threshold=0.01,
            model_kwargs={"torch_dtype": "auto"},
        )
        data = component.to_dict()
@@ -41,13 +48,18 @@ class TestSimilarityRanker:
                 "top_k": 5,
                 "meta_fields_to_embed": [],
                 "embedding_separator": "\n",
+                "scale_score": False,
+                "calibration_factor": None,
+                "score_threshold": 0.01,
                 "model_kwargs": {"torch_dtype": "auto"},
             },
         }

+    @patch("torch.sigmoid")
     @patch("torch.sort")
-    def test_embed_meta(self, mocked_sort):
+    def test_embed_meta(self, mocked_sort, mocked_sigmoid):
         mocked_sort.return_value = (None, torch.tensor([0]))
+        mocked_sigmoid.return_value = torch.tensor([0])
         embedder = TransformersSimilarityRanker(
             model_name_or_path="model", meta_fields_to_embed=["meta_field"], embedding_separator="\n"
         )
@@ -71,16 +83,60 @@ class TestSimilarityRanker:
             return_tensors="pt",
         )

+    @patch("torch.sort")
+    def test_scale_score_false(self, mocked_sort):
+        mocked_sort.return_value = (None, torch.tensor([0, 1]))
+        embedder = TransformersSimilarityRanker(model_name_or_path="model", scale_score=False)
+        embedder.model = MagicMock()
+        embedder.model.return_value = SequenceClassifierOutput(
+            loss=None, logits=torch.FloatTensor([[-10.6859], [-8.9874]]), hidden_states=None, attentions=None
+        )
+        embedder.tokenizer = MagicMock()
+
+        documents = [Document(content="document number 0"), Document(content="document number 1")]
+        out = embedder.run(query="test", documents=documents)
+        assert out["documents"][0].score == pytest.approx(-10.6859, abs=1e-4)
+        assert out["documents"][1].score == pytest.approx(-8.9874, abs=1e-4)
+
+    @patch("torch.sort")
+    def test_score_threshold(self, mocked_sort):
+        mocked_sort.return_value = (None, torch.tensor([0, 1]))
+        embedder = TransformersSimilarityRanker(model_name_or_path="model", scale_score=False, score_threshold=0.1)
+        embedder.model = MagicMock()
+        embedder.model.return_value = SequenceClassifierOutput(
+            loss=None, logits=torch.FloatTensor([[0.955], [0.001]]), hidden_states=None, attentions=None
+        )
+        embedder.tokenizer = MagicMock()
+
+        documents = [Document(content="document number 0"), Document(content="document number 1")]
+        out = embedder.run(query="test", documents=documents)
+        assert len(out["documents"]) == 1
+
     @pytest.mark.integration
     @pytest.mark.parametrize(
-        "query,docs_before_texts,expected_first_text",
+        "query,docs_before_texts,expected_first_text,scores",
         [
-            ("City in Bosnia and Herzegovina", ["Berlin", "Belgrade", "Sarajevo"], "Sarajevo"),
-            ("Machine learning", ["Python", "Bakery in Paris", "Tesla Giga Berlin"], "Python"),
-            ("Cubist movement", ["Nirvana", "Pablo Picasso", "Coffee"], "Pablo Picasso"),
+            (
+                "City in Bosnia and Herzegovina",
+                ["Berlin", "Belgrade", "Sarajevo"],
+                "Sarajevo",
+                [2.2864143829792738e-05, 0.00012495707778725773, 0.009869757108390331],
+            ),
+            (
+                "Machine learning",
+                ["Python", "Bakery in Paris", "Tesla Giga Berlin"],
+                "Python",
+                [1.9063229046878405e-05, 1.434577916370472e-05, 1.3049247172602918e-05],
+            ),
+            (
+                "Cubist movement",
+                ["Nirvana", "Pablo Picasso", "Coffee"],
+                "Pablo Picasso",
+                [1.3313065210240893e-05, 9.90335684036836e-05, 1.3518535524781328e-05],
+            ),
         ],
     )
-    def test_run(self, query, docs_before_texts, expected_first_text):
+    def test_run(self, query, docs_before_texts, expected_first_text, scores):
         """
         Test if the component ranks documents correctly.
         """
@@ -93,8 +149,10 @@ class TestSimilarityRanker:
         assert len(docs_after) == 3
         assert docs_after[0].content == expected_first_text

-        sorted_scores = sorted([doc.score for doc in docs_after], reverse=True)
-        assert [doc.score for doc in docs_after] == sorted_scores
+        sorted_scores = sorted(scores, reverse=True)
+        assert docs_after[0].score == pytest.approx(sorted_scores[0], abs=1e-6)
+        assert docs_after[1].score == pytest.approx(sorted_scores[1], abs=1e-6)
+        assert docs_after[2].score == pytest.approx(sorted_scores[2], abs=1e-6)

     # Returns an empty list if no documents are provided
     @pytest.mark.integration