diff --git a/docs/_src/api/api/ranker.md b/docs/_src/api/api/ranker.md
index 39253dbb9..b5efcbe2d 100644
--- a/docs/_src/api/api/ranker.md
+++ b/docs/_src/api/api/ranker.md
@@ -92,7 +92,7 @@ p.add_node(component=ranker, name="Ranker", inputs=["ESRetriever"])
 #### SentenceTransformersRanker.\_\_init\_\_
 
 ```python
-def __init__(model_name_or_path: Union[str, Path], model_version: Optional[str] = None, top_k: int = 10, use_gpu: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, batch_size: Optional[int] = None)
+def __init__(model_name_or_path: Union[str, Path], model_version: Optional[str] = None, top_k: int = 10, use_gpu: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, batch_size: Optional[int] = None, scale_score: bool = True)
 ```
 
 **Arguments**:
@@ -108,6 +108,9 @@ The strings will be converted into pytorch devices, so use the string notation d
 https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
 (e.g. ["cuda:0"]).
 - `batch_size`: Number of documents to process at a time.
+- `scale_score`: The raw predictions will be transformed using a Sigmoid activation function in case the model
+only predicts a single label. For multi-label predictions, no scaling is applied. Set this
+to False if you do not want any scaling of the raw predictions.
diff --git a/haystack/json-schemas/haystack-pipeline-master.schema.json b/haystack/json-schemas/haystack-pipeline-master.schema.json
index ff5511656..866b31049 100644
--- a/haystack/json-schemas/haystack-pipeline-master.schema.json
+++ b/haystack/json-schemas/haystack-pipeline-master.schema.json
@@ -3775,6 +3775,11 @@
             "batch_size": {
               "title": "Batch Size",
               "type": "integer"
+            },
+            "scale_score": {
+              "title": "Scale Score",
+              "default": true,
+              "type": "boolean"
             }
           },
           "required": [
diff --git a/haystack/nodes/ranker/sentence_transformers.py b/haystack/nodes/ranker/sentence_transformers.py
index 4e052a010..cb8d7a2f9 100644
--- a/haystack/nodes/ranker/sentence_transformers.py
+++ b/haystack/nodes/ranker/sentence_transformers.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, Union, Tuple, Iterator
+from typing import List, Optional, Union, Tuple, Iterator, Any
 import logging
 from pathlib import Path
 
@@ -44,6 +44,7 @@ class SentenceTransformersRanker(BaseRanker):
         use_gpu: bool = True,
         devices: Optional[List[Union[str, torch.device]]] = None,
         batch_size: Optional[int] = None,
+        scale_score: bool = True,
     ):
         """
         :param model_name_or_path: Directory of a saved model or the name of a public model e.g.
@@ -57,6 +58,9 @@ class SentenceTransformersRanker(BaseRanker):
                         https://pytorch.org/docs/stable/tensor_attributes.html?highlight=torch%20device#torch.torch.device
                         (e.g. ["cuda:0"]).
         :param batch_size: Number of documents to process at a time.
+        :param scale_score: The raw predictions will be transformed using a Sigmoid activation function in case the model
+                            only predicts a single label. For multi-label predictions, no scaling is applied. Set this
+                            to False if you do not want any scaling of the raw predictions.
""" super().__init__() @@ -76,6 +80,15 @@ class SentenceTransformersRanker(BaseRanker): ) self.transformer_model.eval() + # we use sigmoid activation function to scale the score in case there is only a single label + # we do not apply any scaling when scale_score is set to False + num_labels = self.transformer_model.num_labels + self.activation_function: torch.nn.Module + if num_labels == 1 and scale_score: + self.activation_function = torch.nn.Sigmoid() + else: + self.activation_function = torch.nn.Identity() + if len(self.devices) > 1: self.model = DataParallel(self.transformer_model, device_ids=self.devices) @@ -119,9 +132,31 @@ class SentenceTransformersRanker(BaseRanker): reverse=True, ) - # rank documents according to scores - sorted_documents = [doc for _, doc in sorted_scores_and_documents] - return sorted_documents[:top_k] + # add normalized scores to documents + sorted_documents = self._add_scores_to_documents(sorted_scores_and_documents[:top_k], logits_dim) + + return sorted_documents + + def _add_scores_to_documents( + self, sorted_scores_and_documents: List[Tuple[Any, Document]], logits_dim: int + ) -> List[Document]: + """ + Normalize and add scores to retrieved result documents. + + :param sorted_scores_and_documents: List of score, Document Tuples. + :param logits_dim: Dimensionality of the returned scores. + """ + sorted_documents = [] + for raw_score, doc in sorted_scores_and_documents: + if logits_dim >= 2: + score = self.activation_function(raw_score)[-1] + else: + score = self.activation_function(raw_score)[0] + + doc.score = score.detach().cpu().numpy().tolist() + sorted_documents.append(doc) + + return sorted_documents def predict_batch( self, @@ -185,9 +220,11 @@ class SentenceTransformersRanker(BaseRanker): reverse=True, ) - # rank documents according to scores - sorted_documents = [doc for _, doc in sorted_scores_and_documents if isinstance(doc, Document)] - return sorted_documents[:top_k] + # is this step needed? + sorted_documents = [(score, doc) for score, doc in sorted_scores_and_documents if isinstance(doc, Document)] + sorted_documents_with_scores = self._add_scores_to_documents(sorted_documents[:top_k], logits_dim) + + return sorted_documents_with_scores else: # Group predictions together grouped_predictions = [] @@ -209,8 +246,12 @@ class SentenceTransformersRanker(BaseRanker): ) # rank documents according to scores - sorted_documents = [doc for _, doc in sorted_scores_and_documents if isinstance(doc, Document)][:top_k] - result.append(sorted_documents) + sorted_documents = [ + (score, doc) for score, doc in sorted_scores_and_documents if isinstance(doc, Document) + ] + sorted_documents_with_scores = self._add_scores_to_documents(sorted_documents[:top_k], logits_dim) + + result.append(sorted_documents_with_scores) return result diff --git a/test/nodes/test_ranker.py b/test/nodes/test_ranker.py index c4836b3be..d7b8e9a19 100644 --- a/test/nodes/test_ranker.py +++ b/test/nodes/test_ranker.py @@ -1,4 +1,5 @@ import pytest +import math from haystack.errors import HaystackError from haystack.schema import Document @@ -173,3 +174,54 @@ def test_ranker_two_logits(ranker_two_logits): ] results = ranker_two_logits.predict(query=query, documents=docs) assert results[0] == docs[4] + + +def test_ranker_returns_normalized_score(ranker): + query = "What is the most important building in King's Landing that has a religious background?" 
+
+    docs = [
+        Document(
+            content="""Aaron Aaron ( or ; ""Ahärôn"") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother's spokesman (""prophet"") to the Pharaoh. Part of the Law (Torah) that Moses received from""",
+            meta={"name": "0"},
+            id="1",
+        )
+    ]
+
+    results = ranker.predict(query=query, documents=docs)
+    score = results[0].score
+    precomputed_score = 5.8796231e-05
+    assert math.isclose(precomputed_score, score, rel_tol=0.01)
+
+
+def test_ranker_returns_raw_score_when_no_scaling():
+    ranker = SentenceTransformersRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-12-v2", scale_score=False)
+    query = "What is the most important building in King's Landing that has a religious background?"
+
+    docs = [
+        Document(
+            content="""Aaron Aaron ( or ; ""Ahärôn"") is a prophet, high priest, and the brother of Moses in the Abrahamic religions. Knowledge of Aaron, along with his brother Moses, comes exclusively from religious texts, such as the Bible and Quran. The Hebrew Bible relates that, unlike Moses, who grew up in the Egyptian royal court, Aaron and his elder sister Miriam remained with their kinsmen in the eastern border-land of Egypt (Goshen). When Moses first confronted the Egyptian king about the Israelites, Aaron served as his brother's spokesman (""prophet"") to the Pharaoh. Part of the Law (Torah) that Moses received from""",
+            meta={"name": "0"},
+            id="1",
+        )
+    ]
+
+    results = ranker.predict(query=query, documents=docs)
+    score = results[0].score
+    precomputed_score = -9.744687
+    assert math.isclose(precomputed_score, score, rel_tol=0.001)
+
+
+def test_ranker_returns_raw_score_for_two_logits(ranker_two_logits):
+    query = "Welches ist das wichtigste Gebäude in Königsmund, das einen religiösen Hintergrund hat?"
+    docs = [
+        Document(
+            content="""Aaron Aaron (oder ; "Ahärôn") ist ein Prophet, Hohepriester und der Bruder von Moses in den abrahamitischen Religionen. Aaron ist ebenso wie sein Bruder Moses ausschließlich aus religiösen Texten wie der Bibel und dem Koran bekannt. Die hebräische Bibel berichtet, dass Aaron und seine ältere Schwester Mirjam im Gegensatz zu Mose, der am ägyptischen Königshof aufwuchs, bei ihren Verwandten im östlichen Grenzland Ägyptens (Goschen) blieben. Als Mose den ägyptischen König zum ersten Mal mit den Israeliten konfrontierte, fungierte Aaron als Sprecher ("Prophet") seines Bruders gegenüber dem Pharao. Ein Teil des Gesetzes (Tora), das Mose von""",
+            meta={"name": "0"},
+            id="1",
+        )
+    ]
+
+    results = ranker_two_logits.predict(query=query, documents=docs)
+    score = results[0].score
+    precomputed_score = -3.61354
+    assert math.isclose(precomputed_score, score, rel_tol=0.001)
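
A minimal usage sketch of the new `scale_score` flag, not part of the diff itself. Everything except the `haystack.nodes` import path (which is assumed here) comes from the code and tests above:

```python
from haystack.nodes import SentenceTransformersRanker  # import path assumed; adjust to your Haystack version
from haystack.schema import Document

docs = [Document(content="Aaron is a prophet, high priest, and the brother of Moses.")]
query = "Who is Aaron?"

# Default behaviour: a single-logit cross-encoder has its raw prediction passed
# through a Sigmoid, so doc.score ends up in [0, 1].
scaled_ranker = SentenceTransformersRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-12-v2")
print(scaled_ranker.predict(query=query, documents=docs)[0].score)

# scale_score=False: the raw logit is attached to the document unchanged,
# as exercised by test_ranker_returns_raw_score_when_no_scaling above.
raw_ranker = SentenceTransformersRanker(
    model_name_or_path="cross-encoder/ms-marco-MiniLM-L-12-v2", scale_score=False
)
print(raw_ranker.predict(query=query, documents=docs)[0].score)
```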