diff --git a/haystack/preview/components/rankers/__init__.py b/haystack/preview/components/rankers/__init__.py new file mode 100644 index 000000000..27337481e --- /dev/null +++ b/haystack/preview/components/rankers/__init__.py @@ -0,0 +1,3 @@ +from haystack.preview.components.rankers.similarity import SimilarityRanker + +__all__ = ["SimilarityRanker"] diff --git a/haystack/preview/components/rankers/similarity.py b/haystack/preview/components/rankers/similarity.py new file mode 100644 index 000000000..7abb37095 --- /dev/null +++ b/haystack/preview/components/rankers/similarity.py @@ -0,0 +1,108 @@ +import logging +from pathlib import Path +from typing import List, Union, Dict, Any + +from haystack.preview import ComponentError, Document, component, default_from_dict, default_to_dict +from haystack.preview.lazy_imports import LazyImport + +logger = logging.getLogger(__name__) + + +with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]==4.32.1'") as torch_and_transformers_import: + import torch + from transformers import AutoModelForSequenceClassification, AutoTokenizer + + +@component +class SimilarityRanker: + """ + Ranks documents based on query similarity. + + Usage example: + ``` + from haystack.preview import Document + from haystack.preview.components.rankers import SimilarityRanker + + sampler = SimilarityRanker() + docs = [Document(text="Paris"), Document(text="Berlin")] + query = "City in Germany" + output = sampler.run(query=query, documents=docs) + docs = output["documents"] + assert len(docs) == 2 + assert docs[0].text == "Berlin" + ``` + """ + + def __init__( + self, model_name_or_path: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2", device: str = "cpu" + ): + """ + Creates an instance of SimilarityRanker. + + :param model_name_or_path: Path to a pre-trained sentence-transformers model. + :param device: torch device (for example, cuda:0, cpu, mps) to limit model inference to a specific device. + """ + torch_and_transformers_import.check() + + self.model_name_or_path = model_name_or_path + self.device = device + self.model = None + self.tokenizer = None + + def warm_up(self): + """ + Warm up the model and tokenizer used in scoring the documents. + """ + if self.model_name_or_path and not self.model: + self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name_or_path) + self.model = self.model.to(self.device) + self.model.eval() + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path) + + def to_dict(self) -> Dict[str, Any]: + """ + Serialize this component to a dictionary. + """ + return default_to_dict(self, device=self.device, model_name_or_path=self.model_name_or_path) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "SimilarityRanker": + """ + Deserialize this component from a dictionary. + """ + return default_from_dict(cls, data) + + @component.output_types(documents=List[Document]) + def run(self, query: str, documents: List[Document]): + """ + Returns a list of documents ranked by their similarity to the given query + + :param query: Query string. + :param documents: List of Documents. + :return: List of Documents sorted by (desc.) similarity with the query. + """ + if not documents: + return {"documents": []} + + # If a model path is provided but the model isn't loaded + if self.model_name_or_path and not self.model: + raise ComponentError( + f"The component {self.__class__.__name__} not warmed up. Run 'warm_up()' before calling 'run()'." + ) + + query_doc_pairs = [[query, doc.text] for doc in documents] + features = self.tokenizer( + query_doc_pairs, padding=True, truncation=True, return_tensors="pt" + ).to( # type: ignore + self.device + ) + with torch.inference_mode(): + similarity_scores = self.model(**features).logits.squeeze() # type: ignore + + _, sorted_indices = torch.sort(similarity_scores, descending=True) + ranked_docs = [] + for sorted_index_tensor in sorted_indices: + i = sorted_index_tensor.item() + documents[i].score = similarity_scores[i].item() + ranked_docs.append(documents[i]) + return {"documents": ranked_docs} diff --git a/releasenotes/notes/add-similarity-ranker-401bf595cea7318a.yaml b/releasenotes/notes/add-similarity-ranker-401bf595cea7318a.yaml new file mode 100644 index 000000000..a0d3217a6 --- /dev/null +++ b/releasenotes/notes/add-similarity-ranker-401bf595cea7318a.yaml @@ -0,0 +1,4 @@ +--- +preview: + - | + Adds SimilarityRanker, a component that ranks a list of Documents based on their similarity to the query. diff --git a/test/preview/components/rankers/test_similarity.py b/test/preview/components/rankers/test_similarity.py new file mode 100644 index 000000000..5ddb3b18d --- /dev/null +++ b/test/preview/components/rankers/test_similarity.py @@ -0,0 +1,74 @@ +import pytest + +from haystack.preview import Document, ComponentError +from haystack.preview.components.rankers.similarity import SimilarityRanker + + +class TestSimilarityRanker: + @pytest.mark.unit + def test_to_dict(self): + component = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2") + data = component.to_dict() + assert data == { + "type": "SimilarityRanker", + "init_parameters": {"device": "cpu", "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"}, + } + + @pytest.mark.unit + def test_to_dict_with_custom_init_parameters(self): + component = SimilarityRanker() + data = component.to_dict() + assert data == { + "type": "SimilarityRanker", + "init_parameters": {"device": "cpu", "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"}, + } + + @pytest.mark.integration + def test_from_dict(self): + data = { + "type": "SimilarityRanker", + "init_parameters": {"device": "cpu", "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"}, + } + component = SimilarityRanker.from_dict(data) + assert component.model_name_or_path == "cross-encoder/ms-marco-MiniLM-L-6-v2" + + @pytest.mark.integration + @pytest.mark.parametrize( + "query,docs_before_texts,expected_first_text", + [ + ("City in Bosnia and Herzegovina", ["Berlin", "Belgrade", "Sarajevo"], "Sarajevo"), + ("Machine learning", ["Python", "Bakery in Paris", "Tesla Giga Berlin"], "Python"), + ("Cubist movement", ["Nirvana", "Pablo Picasso", "Coffee"], "Pablo Picasso"), + ], + ) + def test_run(self, query, docs_before_texts, expected_first_text): + """ + Test if the component ranks documents correctly. + """ + ranker = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2") + ranker.warm_up() + docs_before = [Document(text=text) for text in docs_before_texts] + output = ranker.run(query=query, documents=docs_before) + docs_after = output["documents"] + + assert len(docs_after) == 3 + assert docs_after[0].text == expected_first_text + + sorted_scores = sorted([doc.score for doc in docs_after], reverse=True) + assert [doc.score for doc in docs_after] == sorted_scores + + # Returns an empty list if no documents are provided + @pytest.mark.integration + def test_returns_empty_list_if_no_documents_are_provided(self): + sampler = SimilarityRanker() + sampler.warm_up() + output = sampler.run(query="City in Germany", documents=[]) + assert output["documents"] == [] + + # Raises ComponentError if model is not warmed up + @pytest.mark.integration + def test_raises_component_error_if_model_not_warmed_up(self): + sampler = SimilarityRanker() + + with pytest.raises(ComponentError): + sampler.run(query="query", documents=[Document(text="document")])