diff --git a/docs/pydoc/config/ranker.yml b/docs/pydoc/config/ranker.yml index ff33cc3cf..e0776e7f0 100644 --- a/docs/pydoc/config/ranker.yml +++ b/docs/pydoc/config/ranker.yml @@ -1,7 +1,7 @@ loaders: - type: python search_path: [../../../haystack/nodes/ranker] - modules: ["base", "sentence_transformers", "recentness_ranker"] + modules: ["base", "sentence_transformers", "recentness_ranker", "diversity"] ignore_when_discovered: ["__init__"] processors: - type: filter @@ -24,4 +24,3 @@ renderer: add_method_class_prefix: true add_member_class_prefix: false filename: ranker_api.md - diff --git a/examples/web_lfqa_improved.py b/examples/web_lfqa_improved.py new file mode 100644 index 000000000..c46852617 --- /dev/null +++ b/examples/web_lfqa_improved.py @@ -0,0 +1,49 @@ +import logging +import os + +from haystack import Pipeline +from haystack.nodes import PromptNode, PromptTemplate, TopPSampler, DocumentMerger +from haystack.nodes.ranker.diversity import DiversityRanker +from haystack.nodes.retriever.web import WebRetriever + +search_key = os.environ.get("SERPERDEV_API_KEY") +if not search_key: + raise ValueError("Please set the SERPERDEV_API_KEY environment variable") + +openai_key = os.environ.get("OPENAI_API_KEY") +if not openai_key: + raise ValueError("Please set the OPENAI_API_KEY environment variable") + +prompt_text = """ +Synthesize a comprehensive answer from the following most relevant paragraphs and the given question. +Provide a clear and concise response that summarizes the key points and information presented in the paragraphs. +Your answer should be in your own words and be no longer than 50 words. +\n\n Paragraphs: {documents} \n\n Question: {query} \n\n Answer: +""" + +prompt_node = PromptNode( + "gpt-3.5-turbo", default_prompt_template=PromptTemplate(prompt_text), api_key=openai_key, max_length=256 +) + +web_retriever = WebRetriever(api_key=search_key, top_search_results=10, mode="preprocessed_documents", top_k=25) + +sampler = TopPSampler(top_p=0.95) +ranker = DiversityRanker() +merger = DocumentMerger(separator="\n\n") + +pipeline = Pipeline() +pipeline.add_node(component=web_retriever, name="Retriever", inputs=["Query"]) +pipeline.add_node(component=sampler, name="Sampler", inputs=["Retriever"]) +pipeline.add_node(component=ranker, name="Ranker", inputs=["Sampler"]) +pipeline.add_node(component=merger, name="Merger", inputs=["Ranker"]) +pipeline.add_node(component=prompt_node, name="PromptNode", inputs=["Merger"]) + +logger = logging.getLogger("boilerpy3") +logger.setLevel(logging.CRITICAL) + +questions = ["What are the reasons for long-standing animosities between Russia and Poland?"] + +for q in questions: + print(f"Question: {q}") + response = pipeline.run(query=q) + print(f"Answer: {response['results'][0]}") diff --git a/haystack/nodes/ranker/diversity.py b/haystack/nodes/ranker/diversity.py new file mode 100644 index 000000000..893c751d1 --- /dev/null +++ b/haystack/nodes/ranker/diversity.py @@ -0,0 +1,149 @@ +import logging +from pathlib import Path +from typing import List, Literal, Optional, Union + +from haystack.nodes import BaseRanker +from haystack.schema import Document +from haystack.lazy_imports import LazyImport + +logger = logging.getLogger(__name__) + +with LazyImport(message="Run 'pip install farm-haystack[inference]'") as torch_and_transformers_import: + import torch + from sentence_transformers import SentenceTransformer + from haystack.modeling.utils import initialize_device_settings # pylint: disable=ungrouped-imports + + +class DiversityRanker(BaseRanker): + """ + Implements a document ranking algorithm that orders documents in such a way as to maximize the overall diversity + of the documents. + """ + + def __init__( + self, + model_name_or_path: Union[str, Path] = "all-MiniLM-L6-v2", + use_gpu: Optional[bool] = True, + devices: Optional[List[Union[str, "torch.device"]]] = None, + similarity: Literal["dot_product", "cosine"] = "dot_product", + ): + """ + Initialize a DiversityRanker. + + :param model_name_or_path: Path to a pretrained sentence-transformers model. + :param use_gpu: Whether to use GPU (if available). If no GPUs are available, it falls back on a CPU. + :param devices: List of torch devices (for example, cuda:0, cpu, mps) to limit inference to specific devices. + :param similarity: Whether to use dot product or cosine similarity. Can be set to "dot_product" (default) or "cosine". + """ + torch_and_transformers_import.check() + super().__init__() + self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=True) + self.model = SentenceTransformer(model_name_or_path, device=str(self.devices[0])) + self.similarity = similarity + + def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Document]: + """ + Rank the documents based on their diversity and return the top_k documents. + + :param query: The query. + :param documents: A list of Document objects that should be ranked. + :param top_k: The maximum number of documents to return. + + :return: A list of top_k documents ranked based on diversity. + """ + if query is None or len(query) == 0: + raise ValueError("Query is empty") + if documents is None or len(documents) == 0: + raise ValueError("No documents to choose from") + + diversity_sorted = self.greedy_diversity_order(query=query, documents=documents) + return diversity_sorted[:top_k] + + def greedy_diversity_order(self, query: str, documents: List[Document]) -> List[Document]: + """ + Orders the given list of documents to maximize diversity. The algorithm first calculates embeddings for + each document and the query. It starts by selecting the document that is semantically closest to the query. + Then, for each remaining document, it selects the one that, on average, is least similar to the already + selected documents. This process continues until all documents are selected, resulting in a list where + each subsequent document contributes the most to the overall diversity of the selected set. + + :param query: The search query. + :param documents: The list of Document objects to be ranked. + + :return: A list of documents ordered to maximize diversity. + """ + + # Calculate embeddings + doc_embeddings: torch.Tensor = self.model.encode([d.content for d in documents], convert_to_tensor=True) + query_embedding: torch.Tensor = self.model.encode([query], convert_to_tensor=True) + + if self.similarity == "dot_product": + doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1) + query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1) + + n = len(documents) + selected: List[int] = [] + + # Compute the similarity vector between the query and documents + query_doc_sim: torch.Tensor = query_embedding @ doc_embeddings.T + + # Start with the document with the highest similarity to the query + selected.append(int(torch.argmax(query_doc_sim).item())) + + selected_sum = doc_embeddings[selected[0]] / n + + while len(selected) < n: + # Compute mean of dot products of all selected documents and all other documents + similarities = selected_sum @ doc_embeddings.T + # Mask documents that are already selected + similarities[selected] = torch.inf + # Select the document with the lowest total similarity score + index_unselected = int(torch.argmin(similarities).item()) + + selected.append(index_unselected) + # It's enough just to add to the selected vectors because dot product is distributive + # It's divided by n for numerical stability + selected_sum += doc_embeddings[index_unselected] / n + + ranked_docs: List[Document] = [documents[i] for i in selected] + + return ranked_docs + + def predict_batch( + self, + queries: List[str], + documents: Union[List[Document], List[List[Document]]], + top_k: Optional[float] = None, + batch_size: Optional[int] = None, + ) -> Union[List[Document], List[List[Document]]]: + """ + Rank the documents based on their diversity and return the top_k documents. + + :param queries: The queries. + :param documents: A list (or a list of lists) of Document objects that should be ranked. + :param top_k: The maximum number of documents to return. + :param batch_size: The number of documents to process in one batch. + + :return: A list (or a list of lists) of top_k documents ranked based on diversity. + """ + if queries is None or len(queries) == 0: + raise ValueError("No queries to choose from") + if documents is None or len(documents) == 0: + raise ValueError("No documents to choose from") + if len(documents) > 0 and isinstance(documents[0], Document): + # Docs case 1: single list of Documents -> rerank single list of Documents based on single query + if len(queries) != 1: + raise ValueError("Number of queries must be 1 if a single list of Documents is provided.") + return self.predict(query=queries[0], documents=documents, top_k=top_k) # type: ignore + else: + # Docs case 2: list of lists of Documents -> rerank each list of Documents based on corresponding query + # If queries contains a single query, apply it to each list of Documents + if len(queries) == 1: + queries = queries * len(documents) + if len(queries) != len(documents): + raise ValueError("Number of queries must be equal to number of provided Document lists.") + + results = [] + for query, cur_docs in zip(queries, documents): + results.append(self.predict(query=query, documents=cur_docs, top_k=top_k)) # type: ignore + return results diff --git a/releasenotes/notes/add-most-diverse-ranker-21cf310be4554551.yaml b/releasenotes/notes/add-most-diverse-ranker-21cf310be4554551.yaml new file mode 100644 index 000000000..8f475dd60 --- /dev/null +++ b/releasenotes/notes/add-most-diverse-ranker-21cf310be4554551.yaml @@ -0,0 +1,17 @@ +--- +prelude: > + We're introducing a new ranker to Haystack - DiversityRanker. This + ranker aims to maximize the overall diversity of the given documents. + It leverages sentence-transformer models to calculate semantic embeddings + for each document. It orders documents so that the next one, on average, + is least similar to the already selected documents. Such ranking results in a + list where each subsequent document contributes the most to the overall + diversity of the selected document set. +features: + - | + The DiversityRanker can be used like other rankers in Haystack and + it can be particularly helpful in cases where you have highly relevant + yet similar sets of documents. By ensuring a diversity of documents, + this new ranker facilitates a more comprehensive utilization of the + documents and, particularly in RAG pipelines, potentially contributes + to more accurate and rich model responses. diff --git a/test/nodes/test_diversity_ranker.py b/test/nodes/test_diversity_ranker.py new file mode 100644 index 000000000..857f8a904 --- /dev/null +++ b/test/nodes/test_diversity_ranker.py @@ -0,0 +1,233 @@ +from typing import List + +import pytest + +from haystack import Document +from haystack.nodes.ranker.diversity import DiversityRanker + + +# Tests that predict method returns a list of Document objects +@pytest.mark.integration +@pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) +def test_predict_returns_list_of_documents(similarity: str): + ranker = DiversityRanker(similarity=similarity) # type: ignore + query = "test query" + documents = [Document(content="doc1"), Document(content="doc2")] + result = ranker.predict(query=query, documents=documents) + assert isinstance(result, list) + assert len(result) == 2 + assert all(isinstance(doc, Document) for doc in result) + + +# Tests that predict method returns the correct number of documents +@pytest.mark.integration +@pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) +def test_predict_returns_correct_number_of_documents(similarity: str): + ranker = DiversityRanker(similarity=similarity) # type: ignore + query = "test query" + documents = [Document(content="doc1"), Document(content="doc2")] + result = ranker.predict(query=query, documents=documents, top_k=1) + assert len(result) == 1 + + +# Tests that predict method returns documents in the correct order +@pytest.mark.integration +@pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) +def test_predict_returns_documents_in_correct_order(similarity: str): + ranker = DiversityRanker(similarity=similarity) # type: ignore + query = "city" + documents = [ + Document("France"), + Document("Germany"), + Document("Eiffel Tower"), + Document("Berlin"), + Document("bananas"), + Document("Silicon Valley"), + Document("Brandenburg Gate"), + ] + result = ranker.predict(query=query, documents=documents) + expected_order = "Berlin, bananas, Eiffel Tower, Silicon Valley, France, Brandenburg Gate, Germany" + assert ", ".join([doc.content for doc in result]) == expected_order + + +# Tests that predict_batch method returns a list of lists of Document objects +@pytest.mark.integration +@pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) +def test_predict_batch_returns_list_of_lists_of_documents(similarity: str): + ranker = DiversityRanker(similarity=similarity) # type: ignore + queries = ["test query 1", "test query 2"] + documents = [ + [Document(content="doc1"), Document(content="doc2")], + [Document(content="doc3"), Document(content="doc4")], + ] + result: List[List[Document]] = ranker.predict_batch(queries=queries, documents=documents) + assert isinstance(result, list) + assert all(isinstance(docs, list) for docs in result) + assert all(isinstance(doc, Document) for docs in result for doc in docs) + + +# Tests that predict_batch method returns the correct number of documents +@pytest.mark.integration +@pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) +def test_predict_batch_returns_correct_number_of_documents(similarity: str): + ranker = DiversityRanker(similarity=similarity) # type: ignore + queries = ["test query 1", "test query 2"] + documents = [ + [Document(content="doc1"), Document(content="doc2")], + [Document(content="doc3"), Document(content="doc4")], + ] + result: List[List[Document]] = ranker.predict_batch(queries=queries, documents=documents, top_k=1) + assert len(result) == 2 + assert len(result[0]) == 1 + assert len(result[1]) == 1 + + +# Tests that predict_batch method returns documents in the correct order +@pytest.mark.integration +@pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) +def test_predict_batch_returns_documents_in_correct_order(similarity: str): + ranker = DiversityRanker(similarity=similarity) # type: ignore + queries = ["Berlin", "Paris"] + documents = [ + [Document(content="Germany"), Document(content="Munich"), Document(content="agriculture")], + [Document(content="France"), Document(content="Space exploration"), Document(content="Eiffel Tower")], + ] + result: List[List[Document]] = ranker.predict_batch(queries=queries, documents=documents) + assert len(result) == 2 + + # check the correct most diverse order are in batches + expected_order_0 = "Germany, agriculture, Munich" + expected_order_1 = "France, Space exploration, Eiffel Tower" + assert ", ".join([doc.content for doc in result[0]]) == expected_order_0 + assert ", ".join([doc.content for doc in result[1]]) == expected_order_1 + + +# Tests that predict method returns the correct number of documents for a single document +@pytest.mark.integration +@pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) +def test_predict_single_document_corner_case(similarity: str): + ranker = DiversityRanker(similarity=similarity) # type: ignore + query = "test" + documents = [Document(content="doc1")] + result = ranker.predict(query=query, documents=documents) + assert len(result) == 1 + + +# Tests that predict method raises ValueError if query is empty +@pytest.mark.integration +@pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) +def test_predict_raises_value_error_if_query_is_empty(similarity: str): + ranker = DiversityRanker(similarity=similarity) # type: ignore + query = "" + documents = [Document(content="doc1"), Document(content="doc2")] + with pytest.raises(ValueError): + ranker.predict(query=query, documents=documents) + + +# Tests that predict method raises ValueError if documents is empty +@pytest.mark.integration +@pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) +def test_predict_raises_value_error_if_documents_is_empty(similarity: str): + ranker = DiversityRanker(similarity=similarity) # type: ignore + query = "test query" + documents = [] + with pytest.raises(ValueError): + ranker.predict(query=query, documents=documents) + + +# Tests that predict_batch method raises ValueError if queries is empty +@pytest.mark.integration +@pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) +def test_predict_batch_raises_value_error_if_queries_is_empty(similarity: str): + ranker = DiversityRanker(similarity=similarity) # type: ignore + queries = [] + documents = [ + [Document(content="doc1"), Document(content="doc2")], + [Document(content="doc3"), Document(content="doc4")], + ] + with pytest.raises(ValueError): + ranker.predict_batch(queries=queries, documents=documents) + + +# Tests that predict_batch method raises ValueError if documents is empty +@pytest.mark.integration +@pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) +def test_predict_batch_raises_value_error_if_documents_is_empty(similarity: str): + ranker = DiversityRanker(similarity=similarity) # type: ignore + queries = ["test query 1", "test query 2"] + documents = [] + with pytest.raises(ValueError): + ranker.predict_batch(queries=queries, documents=documents) + + +@pytest.mark.integration +@pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) +def test_predict_real_world_use_case(similarity: str): + ranker = DiversityRanker(similarity=similarity) # type: ignore + query = "What are the reasons for long-standing animosities between Russia and Poland?" + + doc1 = Document( + "One of the earliest known events in Russian-Polish history dates back to 981, when the Grand Prince of Kiev , " + "Vladimir Svyatoslavich , seized the Cherven Cities from the Duchy of Poland . The relationship between two by " + "that time was mostly close and cordial, as there had been no serious wars between both. In 966, Poland " + "accepted Christianity from Rome while Kievan Rus' —the ancestor of Russia, Ukraine and Belarus—was " + "Christianized by Constantinople. In 1054, the internal Christian divide formally split the Church into " + "the Catholic and Orthodox branches separating the Poles from the Eastern Slavs." + ) + + doc2 = Document( + "Since the fall of the Soviet Union , with Lithuania , Ukraine and Belarus regaining independence, the " + "Polish–Russian border has mostly been replaced by borders with the respective countries, but there still " + "is a 210 km long border between Poland and the Kaliningrad Oblast" + ) + + doc3 = Document( + "As part of Poland's plans to become fully energy independent from Russia within the next years, Piotr " + "Wozniak, president of state-controlled oil and gas company PGNiG , stated in February 2019: 'The strategy of " + "the company is just to forget about Eastern suppliers and especially about Gazprom .'[53] In 2020, the " + "Stockholm Arbitral Tribunal ruled that PGNiG's long-term contract gas price with Gazprom linked to oil prices " + "should be changed to approximate the Western European gas market price, backdated to 1 November 2014 when " + "PGNiG requested a price review under the contract. Gazprom had to refund about $1.5 billion to PGNiG." + ) + + doc4 = Document( + "Both Poland and Russia had accused each other for their historical revisionism . Russia has repeatedly " + "accused Poland for not honoring Soviet Red Army soldiers fallen in World War II for Poland, notably in " + "2017, in which Poland was thought on 'attempting to impose its own version of history' after Moscow was " + "not allowed to join an international effort to renovate a World War II museum at Sobibór , site of a " + "notorious Sobibor extermination camp." + ) + + doc5 = Document( + "President of Russia Vladimir Putin and Prime Minister of Poland Leszek Miller in 2002 Modern Polish–Russian " + "relations begin with the fall of communism – 1989 in Poland ( Solidarity and the Polish Round Table " + "Agreement ) and 1991 in Russia ( dissolution of the Soviet Union ). With a new democratic government after " + "the 1989 elections , Poland regained full sovereignty, [2] and what was the Soviet Union, became 15 newly " + "independent states , including the Russian Federation . Relations between modern Poland and Russia suffer " + "from constant ups and downs." + ) + + doc6 = Document( + "Soviet influence in Poland finally ended with the Round Table Agreement of 1989 guaranteeing free elections " + "in Poland, the Revolutions of 1989 against Soviet-sponsored Communist governments in the Eastern Bloc , and " + "finally the formal dissolution of the Warsaw Pact." + ) + + doc7 = Document( + "Dmitry Medvedev and then Polish Prime Minister Donald Tusk , 6 December 2010 BBC News reported that one of " + "the main effects of the 2010 Polish Air Force Tu-154 crash would be the impact it has on Russian-Polish " + "relations. [38] It was thought if the inquiry into the crash were not transparent, it would increase " + "suspicions toward Russia in Poland." + ) + + doc8 = Document( + "Soviet control over the Polish People's Republic lessened after Stalin's death and Gomułka's Thaw , and " + "ceased completely after the fall of the communist government in Poland in late 1989, although the " + "Soviet-Russian Northern Group of Forces did not leave Polish soil until 1993. The continuing Soviet military " + "presence allowed the Soviet Union to heavily influence Polish politics." + ) + + documents = [doc1, doc2, doc3, doc4, doc5, doc6, doc7, doc8] + result = ranker.predict(query=query, documents=documents) + expected_order = [doc5, doc7, doc3, doc1, doc4, doc2, doc6, doc8] + assert result == expected_order