feat: Add DiversityRanker (#5398)

* Introduce DiversityRanker

* improve most_diverse_order speed

* Compute mean for numerical stability

* Add release note

* Add cosine similarity 

* Test both dot product and cosine similarity

* Add pydocs hook

---------

Co-authored-by: Michel Bartels <login@michelbartels.com>
Vladimir Blagojevic 2023-08-01 12:48:34 +02:00 committed by GitHub
parent 8c017ccc32
commit 540d0fad97
5 changed files with 449 additions and 2 deletions

View File

@@ -1,7 +1,7 @@
 loaders:
   - type: python
     search_path: [../../../haystack/nodes/ranker]
-    modules: ["base", "sentence_transformers", "recentness_ranker"]
+    modules: ["base", "sentence_transformers", "recentness_ranker", "diversity"]
     ignore_when_discovered: ["__init__"]
 processors:
   - type: filter
@@ -24,4 +24,3 @@ renderer:
   add_method_class_prefix: true
   add_member_class_prefix: false
   filename: ranker_api.md
-
View File

@@ -0,0 +1,49 @@
import logging
import os

from haystack import Pipeline
from haystack.nodes import PromptNode, PromptTemplate, TopPSampler, DocumentMerger
from haystack.nodes.ranker.diversity import DiversityRanker
from haystack.nodes.retriever.web import WebRetriever

search_key = os.environ.get("SERPERDEV_API_KEY")
if not search_key:
    raise ValueError("Please set the SERPERDEV_API_KEY environment variable")

openai_key = os.environ.get("OPENAI_API_KEY")
if not openai_key:
    raise ValueError("Please set the OPENAI_API_KEY environment variable")

prompt_text = """
Synthesize a comprehensive answer from the following most relevant paragraphs and the given question.
Provide a clear and concise response that summarizes the key points and information presented in the paragraphs.
Your answer should be in your own words and be no longer than 50 words.
\n\n Paragraphs: {documents} \n\n Question: {query} \n\n Answer:
"""

prompt_node = PromptNode(
    "gpt-3.5-turbo", default_prompt_template=PromptTemplate(prompt_text), api_key=openai_key, max_length=256
)

web_retriever = WebRetriever(api_key=search_key, top_search_results=10, mode="preprocessed_documents", top_k=25)
sampler = TopPSampler(top_p=0.95)
ranker = DiversityRanker()
merger = DocumentMerger(separator="\n\n")

# Wire the nodes into a query pipeline: retrieve -> sample -> diversify -> merge -> generate
pipeline = Pipeline()
pipeline.add_node(component=web_retriever, name="Retriever", inputs=["Query"])
pipeline.add_node(component=sampler, name="Sampler", inputs=["Retriever"])
pipeline.add_node(component=ranker, name="Ranker", inputs=["Sampler"])
pipeline.add_node(component=merger, name="Merger", inputs=["Ranker"])
pipeline.add_node(component=prompt_node, name="PromptNode", inputs=["Merger"])

# Silence noisy boilerpy3 extraction warnings triggered by the web retriever
logger = logging.getLogger("boilerpy3")
logger.setLevel(logging.CRITICAL)

questions = ["What are the reasons for long-standing animosities between Russia and Poland?"]

for q in questions:
    print(f"Question: {q}")
    response = pipeline.run(query=q)
    print(f"Answer: {response['results'][0]}")

View File

@@ -0,0 +1,149 @@
import logging
from pathlib import Path
from typing import List, Literal, Optional, Union

from haystack.nodes import BaseRanker
from haystack.schema import Document
from haystack.lazy_imports import LazyImport

logger = logging.getLogger(__name__)

with LazyImport(message="Run 'pip install farm-haystack[inference]'") as torch_and_transformers_import:
    import torch
    from sentence_transformers import SentenceTransformer
    from haystack.modeling.utils import initialize_device_settings  # pylint: disable=ungrouped-imports


class DiversityRanker(BaseRanker):
    """
    Implements a document ranking algorithm that orders documents in such a way as to maximize the overall diversity
    of the documents.
    """

    def __init__(
        self,
        model_name_or_path: Union[str, Path] = "all-MiniLM-L6-v2",
        use_gpu: Optional[bool] = True,
        devices: Optional[List[Union[str, "torch.device"]]] = None,
        similarity: Literal["dot_product", "cosine"] = "dot_product",
    ):
        """
        Initialize a DiversityRanker.

        :param model_name_or_path: Path to a pretrained sentence-transformers model.
        :param use_gpu: Whether to use GPU (if available). If no GPUs are available, it falls back on a CPU.
        :param devices: List of torch devices (for example, cuda:0, cpu, mps) to limit inference to specific devices.
        :param similarity: Whether to use dot product or cosine similarity. Can be set to "dot_product" (default) or "cosine".
        """
        torch_and_transformers_import.check()
        super().__init__()
        self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=True)
        self.model = SentenceTransformer(model_name_or_path, device=str(self.devices[0]))
        self.similarity = similarity

    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Document]:
        """
        Rank the documents based on their diversity and return the top_k documents.

        :param query: The query.
        :param documents: A list of Document objects that should be ranked.
        :param top_k: The maximum number of documents to return.
        :return: A list of top_k documents ranked based on diversity.
        """
        if query is None or len(query) == 0:
            raise ValueError("Query is empty")
        if documents is None or len(documents) == 0:
            raise ValueError("No documents to choose from")

        diversity_sorted = self.greedy_diversity_order(query=query, documents=documents)
        return diversity_sorted[:top_k]

    def greedy_diversity_order(self, query: str, documents: List[Document]) -> List[Document]:
        """
        Orders the given list of documents to maximize diversity. The algorithm first calculates embeddings for
        each document and the query. It starts by selecting the document that is semantically closest to the query.
        Then, for each remaining document, it selects the one that, on average, is least similar to the already
        selected documents. This process continues until all documents are selected, resulting in a list where
        each subsequent document contributes the most to the overall diversity of the selected set.

        :param query: The search query.
        :param documents: The list of Document objects to be ranked.
        :return: A list of documents ordered to maximize diversity.
        """
        # Calculate embeddings
        doc_embeddings: torch.Tensor = self.model.encode([d.content for d in documents], convert_to_tensor=True)
        query_embedding: torch.Tensor = self.model.encode([query], convert_to_tensor=True)

        if self.similarity == "cosine":
            # Normalize to unit length so that the dot products below are cosine similarities
            doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
            query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)

        n = len(documents)
        selected: List[int] = []

        # Compute the similarity vector between the query and documents
        query_doc_sim: torch.Tensor = query_embedding @ doc_embeddings.T

        # Start with the document with the highest similarity to the query
        selected.append(int(torch.argmax(query_doc_sim).item()))
        selected_sum = doc_embeddings[selected[0]] / n

        while len(selected) < n:
            # Compute the mean similarity between the selected documents and every other document
            similarities = selected_sum @ doc_embeddings.T
            # Mask documents that are already selected
            similarities[selected] = torch.inf
            # Select the document with the lowest mean similarity score
            index_unselected = int(torch.argmin(similarities).item())
            selected.append(index_unselected)
            # A running sum of the selected vectors is enough because the dot product distributes
            # over addition: mean_i(d_i . x) = (sum_i d_i / n) . x. Each term is divided by n so
            # the sum stays bounded, which helps numerical stability.
            selected_sum += doc_embeddings[index_unselected] / n

        ranked_docs: List[Document] = [documents[i] for i in selected]

        return ranked_docs

    def predict_batch(
        self,
        queries: List[str],
        documents: Union[List[Document], List[List[Document]]],
        top_k: Optional[int] = None,
        batch_size: Optional[int] = None,
    ) -> Union[List[Document], List[List[Document]]]:
        """
        Rank the documents based on their diversity and return the top_k documents.

        :param queries: The queries.
        :param documents: A list (or a list of lists) of Document objects that should be ranked.
        :param top_k: The maximum number of documents to return.
        :param batch_size: The number of documents to process in one batch.
        :return: A list (or a list of lists) of top_k documents ranked based on diversity.
        """
        if queries is None or len(queries) == 0:
            raise ValueError("No queries to choose from")
        if documents is None or len(documents) == 0:
            raise ValueError("No documents to choose from")

        if len(documents) > 0 and isinstance(documents[0], Document):
            # Docs case 1: single list of Documents -> rerank single list of Documents based on single query
            if len(queries) != 1:
                raise ValueError("Number of queries must be 1 if a single list of Documents is provided.")
            return self.predict(query=queries[0], documents=documents, top_k=top_k)  # type: ignore
        else:
            # Docs case 2: list of lists of Documents -> rerank each list of Documents based on corresponding query
            # If queries contains a single query, apply it to each list of Documents
            if len(queries) == 1:
                queries = queries * len(documents)
            if len(queries) != len(documents):
                raise ValueError("Number of queries must be equal to number of provided Document lists.")

            results = []
            for query, cur_docs in zip(queries, documents):
                results.append(self.predict(query=query, documents=cur_docs, top_k=top_k))  # type: ignore
            return results
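
For quick reference, here is a minimal standalone usage sketch of the new ranker (not part of the commit; the document contents are invented for illustration):

from haystack import Document
from haystack.nodes.ranker.diversity import DiversityRanker

# Rerank so that each successive document adds the most diversity to the list.
ranker = DiversityRanker(similarity="cosine", use_gpu=False)
docs = [Document(content=c) for c in ["Berlin", "Munich", "bananas", "Brandenburg Gate"]]
reranked = ranker.predict(query="city", documents=docs, top_k=4)
print([d.content for d in reranked])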

View File

@@ -0,0 +1,17 @@
---
prelude: >
  We're introducing a new ranker to Haystack - DiversityRanker. This
  ranker aims to maximize the overall diversity of the given documents.
  It leverages sentence-transformer models to calculate semantic embeddings
  for each document and orders the documents so that the next one is, on
  average, the least similar to the already selected documents. Such ranking
  results in a list where each subsequent document contributes the most to
  the overall diversity of the selected document set.
features:
  - |
    The DiversityRanker can be used like any other ranker in Haystack, and
    it is particularly helpful when you have highly relevant yet similar
    sets of documents. By ensuring diversity among the documents, this new
    ranker facilitates more comprehensive use of the documents and,
    particularly in RAG pipelines, potentially contributes to more accurate
    and rich model responses.
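
To make the selection rule concrete, here is a small self-contained sketch of the greedy ordering on made-up toy vectors (not part of the release note; the numbers are invented for illustration):

import torch

# Toy "embeddings": one query and four documents, normalized to unit length.
query = torch.tensor([1.0, 0.0])
docs = torch.tensor([[0.9, 0.1], [0.8, 0.2], [-0.5, 0.5], [0.1, 0.9]])
query = query / torch.norm(query)
docs = docs / torch.norm(docs, dim=-1, keepdim=True)

n = len(docs)
selected = [int(torch.argmax(docs @ query))]  # start closest to the query
selected_sum = docs[selected[0]] / n          # running sum of picks, scaled by 1/n

while len(selected) < n:
    sims = selected_sum @ docs.T    # mean similarity of each doc to the picks
    sims[selected] = torch.inf      # never re-pick a document
    pick = int(torch.argmin(sims))  # least similar to what we already have
    selected.append(pick)
    selected_sum += docs[pick] / n

print(selected)  # [0, 2, 1, 3]: closest to the query first, then greedily most diverse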

View File

@@ -0,0 +1,233 @@
from typing import List

import pytest

from haystack import Document
from haystack.nodes.ranker.diversity import DiversityRanker


# Tests that predict method returns a list of Document objects
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_returns_list_of_documents(similarity: str):
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    query = "test query"
    documents = [Document(content="doc1"), Document(content="doc2")]
    result = ranker.predict(query=query, documents=documents)
    assert isinstance(result, list)
    assert len(result) == 2
    assert all(isinstance(doc, Document) for doc in result)


# Tests that predict method returns the correct number of documents
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_returns_correct_number_of_documents(similarity: str):
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    query = "test query"
    documents = [Document(content="doc1"), Document(content="doc2")]
    result = ranker.predict(query=query, documents=documents, top_k=1)
    assert len(result) == 1


# Tests that predict method returns documents in the correct order
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_returns_documents_in_correct_order(similarity: str):
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    query = "city"
    documents = [
        Document("France"),
        Document("Germany"),
        Document("Eiffel Tower"),
        Document("Berlin"),
        Document("bananas"),
        Document("Silicon Valley"),
        Document("Brandenburg Gate"),
    ]
    result = ranker.predict(query=query, documents=documents)
    expected_order = "Berlin, bananas, Eiffel Tower, Silicon Valley, France, Brandenburg Gate, Germany"
    assert ", ".join([doc.content for doc in result]) == expected_order


# Tests that predict_batch method returns a list of lists of Document objects
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_returns_list_of_lists_of_documents(similarity: str):
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    queries = ["test query 1", "test query 2"]
    documents = [
        [Document(content="doc1"), Document(content="doc2")],
        [Document(content="doc3"), Document(content="doc4")],
    ]
    result: List[List[Document]] = ranker.predict_batch(queries=queries, documents=documents)
    assert isinstance(result, list)
    assert all(isinstance(docs, list) for docs in result)
    assert all(isinstance(doc, Document) for docs in result for doc in docs)


# Tests that predict_batch method returns the correct number of documents
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_returns_correct_number_of_documents(similarity: str):
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    queries = ["test query 1", "test query 2"]
    documents = [
        [Document(content="doc1"), Document(content="doc2")],
        [Document(content="doc3"), Document(content="doc4")],
    ]
    result: List[List[Document]] = ranker.predict_batch(queries=queries, documents=documents, top_k=1)
    assert len(result) == 2
    assert len(result[0]) == 1
    assert len(result[1]) == 1


# Tests that predict_batch method returns documents in the correct order
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_returns_documents_in_correct_order(similarity: str):
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    queries = ["Berlin", "Paris"]
    documents = [
        [Document(content="Germany"), Document(content="Munich"), Document(content="agriculture")],
        [Document(content="France"), Document(content="Space exploration"), Document(content="Eiffel Tower")],
    ]
    result: List[List[Document]] = ranker.predict_batch(queries=queries, documents=documents)
    assert len(result) == 2
    # Check that each list of Documents comes back in its most diverse order
    expected_order_0 = "Germany, agriculture, Munich"
    expected_order_1 = "France, Space exploration, Eiffel Tower"
    assert ", ".join([doc.content for doc in result[0]]) == expected_order_0
    assert ", ".join([doc.content for doc in result[1]]) == expected_order_1


# Tests that predict method returns the correct number of documents for a single document
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_single_document_corner_case(similarity: str):
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    query = "test"
    documents = [Document(content="doc1")]
    result = ranker.predict(query=query, documents=documents)
    assert len(result) == 1


# Tests that predict method raises ValueError if query is empty
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_raises_value_error_if_query_is_empty(similarity: str):
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    query = ""
    documents = [Document(content="doc1"), Document(content="doc2")]
    with pytest.raises(ValueError):
        ranker.predict(query=query, documents=documents)


# Tests that predict method raises ValueError if documents is empty
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_raises_value_error_if_documents_is_empty(similarity: str):
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    query = "test query"
    documents = []
    with pytest.raises(ValueError):
        ranker.predict(query=query, documents=documents)


# Tests that predict_batch method raises ValueError if queries is empty
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_raises_value_error_if_queries_is_empty(similarity: str):
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    queries = []
    documents = [
        [Document(content="doc1"), Document(content="doc2")],
        [Document(content="doc3"), Document(content="doc4")],
    ]
    with pytest.raises(ValueError):
        ranker.predict_batch(queries=queries, documents=documents)


# Tests that predict_batch method raises ValueError if documents is empty
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_raises_value_error_if_documents_is_empty(similarity: str):
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    queries = ["test query 1", "test query 2"]
    documents = []
    with pytest.raises(ValueError):
        ranker.predict_batch(queries=queries, documents=documents)


@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_real_world_use_case(similarity: str):
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    query = "What are the reasons for long-standing animosities between Russia and Poland?"
    doc1 = Document(
        "One of the earliest known events in Russian-Polish history dates back to 981, when the Grand Prince of Kiev , "
        "Vladimir Svyatoslavich , seized the Cherven Cities from the Duchy of Poland . The relationship between two by "
        "that time was mostly close and cordial, as there had been no serious wars between both. In 966, Poland "
        "accepted Christianity from Rome while Kievan Rus' —the ancestor of Russia, Ukraine and Belarus—was "
        "Christianized by Constantinople. In 1054, the internal Christian divide formally split the Church into "
        "the Catholic and Orthodox branches separating the Poles from the Eastern Slavs."
    )
    doc2 = Document(
        "Since the fall of the Soviet Union , with Lithuania , Ukraine and Belarus regaining independence, the "
        "PolishRussian border has mostly been replaced by borders with the respective countries, but there still "
        "is a 210 km long border between Poland and the Kaliningrad Oblast"
    )
    doc3 = Document(
        "As part of Poland's plans to become fully energy independent from Russia within the next years, Piotr "
        "Wozniak, president of state-controlled oil and gas company PGNiG , stated in February 2019: 'The strategy of "
        "the company is just to forget about Eastern suppliers and especially about Gazprom .'[53] In 2020, the "
        "Stockholm Arbitral Tribunal ruled that PGNiG's long-term contract gas price with Gazprom linked to oil prices "
        "should be changed to approximate the Western European gas market price, backdated to 1 November 2014 when "
        "PGNiG requested a price review under the contract. Gazprom had to refund about $1.5 billion to PGNiG."
    )
    doc4 = Document(
        "Both Poland and Russia had accused each other for their historical revisionism . Russia has repeatedly "
        "accused Poland for not honoring Soviet Red Army soldiers fallen in World War II for Poland, notably in "
        "2017, in which Poland was thought on 'attempting to impose its own version of history' after Moscow was "
        "not allowed to join an international effort to renovate a World War II museum at Sobibór , site of a "
        "notorious Sobibor extermination camp."
    )
    doc5 = Document(
        "President of Russia Vladimir Putin and Prime Minister of Poland Leszek Miller in 2002 Modern PolishRussian "
        "relations begin with the fall of communism 1989 in Poland ( Solidarity and the Polish Round Table "
        "Agreement ) and 1991 in Russia ( dissolution of the Soviet Union ). With a new democratic government after "
        "the 1989 elections , Poland regained full sovereignty, [2] and what was the Soviet Union, became 15 newly "
        "independent states , including the Russian Federation . Relations between modern Poland and Russia suffer "
        "from constant ups and downs."
    )
    doc6 = Document(
        "Soviet influence in Poland finally ended with the Round Table Agreement of 1989 guaranteeing free elections "
        "in Poland, the Revolutions of 1989 against Soviet-sponsored Communist governments in the Eastern Bloc , and "
        "finally the formal dissolution of the Warsaw Pact."
    )
    doc7 = Document(
        "Dmitry Medvedev and then Polish Prime Minister Donald Tusk , 6 December 2010 BBC News reported that one of "
        "the main effects of the 2010 Polish Air Force Tu-154 crash would be the impact it has on Russian-Polish "
        "relations. [38] It was thought if the inquiry into the crash were not transparent, it would increase "
        "suspicions toward Russia in Poland."
    )
    doc8 = Document(
        "Soviet control over the Polish People's Republic lessened after Stalin's death and Gomułka's Thaw , and "
        "ceased completely after the fall of the communist government in Poland in late 1989, although the "
        "Soviet-Russian Northern Group of Forces did not leave Polish soil until 1993. The continuing Soviet military "
        "presence allowed the Soviet Union to heavily influence Polish politics."
    )
    documents = [doc1, doc2, doc3, doc4, doc5, doc6, doc7, doc8]
    result = ranker.predict(query=query, documents=documents)
    expected_order = [doc5, doc7, doc3, doc1, doc4, doc2, doc6, doc8]
    assert result == expected_order