From d51be9edac914e281502075c5267d3e4b74c42fc Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Thu, 12 Oct 2023 13:52:01 +0200 Subject: [PATCH] Add top_k to SimilarityRanker (#6036) --- .../preview/components/rankers/similarity.py | 24 ++++++++--- .../components/rankers/test_similarity.py | 43 +++++++++++++++++-- 2 files changed, 59 insertions(+), 8 deletions(-) diff --git a/haystack/preview/components/rankers/similarity.py b/haystack/preview/components/rankers/similarity.py index 7abb37095..66f620294 100644 --- a/haystack/preview/components/rankers/similarity.py +++ b/haystack/preview/components/rankers/similarity.py @@ -1,6 +1,6 @@ import logging from pathlib import Path -from typing import List, Union, Dict, Any +from typing import List, Union, Dict, Any, Optional from haystack.preview import ComponentError, Document, component, default_from_dict, default_to_dict from haystack.preview.lazy_imports import LazyImport @@ -34,17 +34,24 @@ class SimilarityRanker: """ def __init__( - self, model_name_or_path: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2", device: str = "cpu" + self, + model_name_or_path: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2", + top_k: int = 10, + device: str = "cpu", ): """ Creates an instance of SimilarityRanker. :param model_name_or_path: Path to a pre-trained sentence-transformers model. + :param top_k: The maximum number of documents to return per query. :param device: torch device (for example, cuda:0, cpu, mps) to limit model inference to a specific device. """ torch_and_transformers_import.check() self.model_name_or_path = model_name_or_path + if top_k <= 0: + raise ValueError(f"top_k must be > 0, but got {top_k}") + self.top_k = top_k self.device = device self.model = None self.tokenizer = None @@ -63,7 +70,7 @@ class SimilarityRanker: """ Serialize this component to a dictionary. """ - return default_to_dict(self, device=self.device, model_name_or_path=self.model_name_or_path) + return default_to_dict(self, top_k=self.top_k, device=self.device, model_name_or_path=self.model_name_or_path) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "SimilarityRanker": @@ -73,17 +80,24 @@ class SimilarityRanker: return default_from_dict(cls, data) @component.output_types(documents=List[Document]) - def run(self, query: str, documents: List[Document]): + def run(self, query: str, documents: List[Document], top_k: Optional[int] = None): """ Returns a list of documents ranked by their similarity to the given query :param query: Query string. :param documents: List of Documents. + :param top_k: The maximum number of documents to return. :return: List of Documents sorted by (desc.) similarity with the query. """ if not documents: return {"documents": []} + if top_k is None: + top_k = self.top_k + + elif top_k <= 0: + raise ValueError(f"top_k must be > 0, but got {top_k}") + # If a model path is provided but the model isn't loaded if self.model_name_or_path and not self.model: raise ComponentError( @@ -105,4 +119,4 @@ class SimilarityRanker: i = sorted_index_tensor.item() documents[i].score = similarity_scores[i].item() ranked_docs.append(documents[i]) - return {"documents": ranked_docs} + return {"documents": ranked_docs[:top_k]} diff --git a/test/preview/components/rankers/test_similarity.py b/test/preview/components/rankers/test_similarity.py index 5ddb3b18d..cc2486a2c 100644 --- a/test/preview/components/rankers/test_similarity.py +++ b/test/preview/components/rankers/test_similarity.py @@ -11,7 +11,11 @@ class TestSimilarityRanker: data = component.to_dict() assert data == { "type": "SimilarityRanker", - "init_parameters": {"device": "cpu", "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"}, + "init_parameters": { + "device": "cpu", + "top_k": 10, + "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2", + }, } @pytest.mark.unit @@ -20,14 +24,22 @@ class TestSimilarityRanker: data = component.to_dict() assert data == { "type": "SimilarityRanker", - "init_parameters": {"device": "cpu", "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"}, + "init_parameters": { + "device": "cpu", + "top_k": 10, + "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2", + }, } @pytest.mark.integration def test_from_dict(self): data = { "type": "SimilarityRanker", - "init_parameters": {"device": "cpu", "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"}, + "init_parameters": { + "device": "cpu", + "top_k": 10, + "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2", + }, } component = SimilarityRanker.from_dict(data) assert component.model_name_or_path == "cross-encoder/ms-marco-MiniLM-L-6-v2" @@ -72,3 +84,28 @@ class TestSimilarityRanker: with pytest.raises(ComponentError): sampler.run(query="query", documents=[Document(text="document")]) + + @pytest.mark.integration + @pytest.mark.parametrize( + "query,docs_before_texts,expected_first_text", + [ + ("City in Bosnia and Herzegovina", ["Berlin", "Belgrade", "Sarajevo"], "Sarajevo"), + ("Machine learning", ["Python", "Bakery in Paris", "Tesla Giga Berlin"], "Python"), + ("Cubist movement", ["Nirvana", "Pablo Picasso", "Coffee"], "Pablo Picasso"), + ], + ) + def test_run_top_k(self, query, docs_before_texts, expected_first_text): + """ + Test if the component ranks documents correctly with a custom top_k. + """ + ranker = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2) + ranker.warm_up() + docs_before = [Document(text=text) for text in docs_before_texts] + output = ranker.run(query=query, documents=docs_before) + docs_after = output["documents"] + + assert len(docs_after) == 2 + assert docs_after[0].text == expected_first_text + + sorted_scores = sorted([doc.score for doc in docs_after], reverse=True) + assert [doc.score for doc in docs_after] == sorted_scores