mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-25 05:58:57 +00:00
feat: Add DiversityRanker (#5398)
* Introduce DiversityRanker * improve most_diverse_order speed * Compute mean for numerical stability * Add release note * Add cosine similarity * Test both dot product and cosine similarity * Add pydocs hook --------- Co-authored-by: Michel Bartels <login@michelbartels.com>
This commit is contained in:
parent
8c017ccc32
commit
540d0fad97
@ -1,7 +1,7 @@
|
||||
loaders:
|
||||
- type: python
|
||||
search_path: [../../../haystack/nodes/ranker]
|
||||
modules: ["base", "sentence_transformers", "recentness_ranker"]
|
||||
modules: ["base", "sentence_transformers", "recentness_ranker", "diversity"]
|
||||
ignore_when_discovered: ["__init__"]
|
||||
processors:
|
||||
- type: filter
|
||||
@ -24,4 +24,3 @@ renderer:
|
||||
add_method_class_prefix: true
|
||||
add_member_class_prefix: false
|
||||
filename: ranker_api.md
|
||||
|
||||
|
||||
49
examples/web_lfqa_improved.py
Normal file
49
examples/web_lfqa_improved.py
Normal file
@ -0,0 +1,49 @@
|
||||
"""Example: long-form question answering over web search results.

Retrieved web documents are sampled with top-p, re-ordered by the DiversityRanker, merged into a single
document and handed to an OpenAI-backed PromptNode that writes a short answer.
"""
import logging
import os

from haystack import Pipeline
from haystack.nodes import PromptNode, PromptTemplate, TopPSampler, DocumentMerger
from haystack.nodes.ranker.diversity import DiversityRanker
from haystack.nodes.retriever.web import WebRetriever


def _require_env(name: str) -> str:
    """Return the value of environment variable `name`, raising ValueError when it is unset or empty."""
    value = os.environ.get(name)
    if not value:
        raise ValueError(f"Please set the {name} environment variable")
    return value


# Both API keys are mandatory; fail before any component is constructed.
search_key = _require_env("SERPERDEV_API_KEY")
openai_key = _require_env("OPENAI_API_KEY")

prompt_text = """
Synthesize a comprehensive answer from the following most relevant paragraphs and the given question.
Provide a clear and concise response that summarizes the key points and information presented in the paragraphs.
Your answer should be in your own words and be no longer than 50 words.
\n\n Paragraphs: {documents} \n\n Question: {query} \n\n Answer:
"""

prompt_node = PromptNode(
    "gpt-3.5-turbo",
    default_prompt_template=PromptTemplate(prompt_text),
    api_key=openai_key,
    max_length=256,
)
web_retriever = WebRetriever(api_key=search_key, top_search_results=10, mode="preprocessed_documents", top_k=25)
sampler = TopPSampler(top_p=0.95)
ranker = DiversityRanker()
merger = DocumentMerger(separator="\n\n")

# Linear pipeline: every node consumes the output of the node registered just before it.
pipeline = Pipeline()
upstream = "Query"
for node_name, component in (
    ("Retriever", web_retriever),
    ("Sampler", sampler),
    ("Ranker", ranker),
    ("Merger", merger),
    ("PromptNode", prompt_node),
):
    pipeline.add_node(component=component, name=node_name, inputs=[upstream])
    upstream = node_name

# Only let CRITICAL records from the "boilerpy3" logger through.
logger = logging.getLogger("boilerpy3")
logger.setLevel(logging.CRITICAL)

questions = ["What are the reasons for long-standing animosities between Russia and Poland?"]

for question in questions:
    print(f"Question: {question}")
    answer = pipeline.run(query=question)["results"][0]
    print(f"Answer: {answer}")
|
||||
149
haystack/nodes/ranker/diversity.py
Normal file
149
haystack/nodes/ranker/diversity.py
Normal file
@ -0,0 +1,149 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from haystack.nodes import BaseRanker
|
||||
from haystack.schema import Document
|
||||
from haystack.lazy_imports import LazyImport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
with LazyImport(message="Run 'pip install farm-haystack[inference]'") as torch_and_transformers_import:
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from haystack.modeling.utils import initialize_device_settings # pylint: disable=ungrouped-imports
|
||||
|
||||
|
||||
class DiversityRanker(BaseRanker):
    """
    Implements a document ranking algorithm that orders documents in such a way as to maximize the overall diversity
    of the documents.

    Documents and the query are embedded with a sentence-transformers model. The first document is the one most
    similar to the query; every following document is the one that is, on average, least similar to the documents
    that were already selected.
    """

    def __init__(
        self,
        model_name_or_path: Union[str, Path] = "all-MiniLM-L6-v2",
        use_gpu: Optional[bool] = True,
        devices: Optional[List[Union[str, "torch.device"]]] = None,
        similarity: Literal["dot_product", "cosine"] = "dot_product",
    ):
        """
        Initialize a DiversityRanker.

        :param model_name_or_path: Path to a pretrained sentence-transformers model.
        :param use_gpu: Whether to use GPU (if available). If no GPUs are available, it falls back on a CPU.
        :param devices: List of torch devices (for example, cuda:0, cpu, mps) to limit inference to specific devices.
        :param similarity: Whether to use dot product or cosine similarity. Can be set to "dot_product" (default) or "cosine".

        :raises ValueError: If `similarity` is neither "dot_product" nor "cosine".
        """
        torch_and_transformers_import.check()
        super().__init__()
        # Reject unknown values early instead of silently treating them as one of the supported modes.
        if similarity not in ("dot_product", "cosine"):
            raise ValueError(f"Unsupported similarity '{similarity}'. Use 'dot_product' or 'cosine'.")
        self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=True)
        # Only the first resolved device is used to host the embedding model.
        self.model = SentenceTransformer(model_name_or_path, device=str(self.devices[0]))
        self.similarity = similarity

    def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> List[Document]:
        """
        Rank the documents based on their diversity and return the top_k documents.

        :param query: The query.
        :param documents: A list of Document objects that should be ranked.
        :param top_k: The maximum number of documents to return. If None, all documents are returned.

        :return: A list of top_k documents ranked based on diversity.
        :raises ValueError: If the query or the list of documents is None or empty.
        """
        if not query:
            raise ValueError("Query is empty")
        if not documents:
            raise ValueError("No documents to choose from")

        diversity_sorted = self.greedy_diversity_order(query=query, documents=documents)
        # Slicing with None keeps the full list.
        return diversity_sorted[:top_k]

    def greedy_diversity_order(self, query: str, documents: List[Document]) -> List[Document]:
        """
        Orders the given list of documents to maximize diversity. The algorithm first calculates embeddings for
        each document and the query. It starts by selecting the document that is semantically closest to the query.
        Then, for each remaining document, it selects the one that, on average, is least similar to the already
        selected documents. This process continues until all documents are selected, resulting in a list where
        each subsequent document contributes the most to the overall diversity of the selected set.

        :param query: The search query.
        :param documents: The list of Document objects to be ranked.

        :return: A list of documents ordered to maximize diversity.
        """

        # Calculate embeddings
        doc_embeddings: torch.Tensor = self.model.encode([d.content for d in documents], convert_to_tensor=True)
        query_embedding: torch.Tensor = self.model.encode([query], convert_to_tensor=True)

        # Cosine similarity is the dot product of L2-normalised vectors, so normalisation must happen for
        # "cosine" only; "dot_product" works on the raw embeddings.
        if self.similarity == "cosine":
            doc_embeddings = doc_embeddings / torch.norm(doc_embeddings, p=2, dim=-1, keepdim=True)
            query_embedding = query_embedding / torch.norm(query_embedding, p=2, dim=-1, keepdim=True)

        n = len(documents)
        selected: List[int] = []

        # Compute the similarity vector between the query and documents
        query_doc_sim: torch.Tensor = query_embedding @ doc_embeddings.T

        # Start with the document with the highest similarity to the query
        selected.append(int(torch.argmax(query_doc_sim).item()))

        # Running sum of the selected embeddings, pre-divided by n (see the note inside the loop).
        selected_sum = doc_embeddings[selected[0]] / n

        while len(selected) < n:
            # Compute mean of dot products of all selected documents and all other documents
            similarities = selected_sum @ doc_embeddings.T
            # Mask documents that are already selected
            similarities[selected] = torch.inf
            # Select the document with the lowest total similarity score
            index_unselected = int(torch.argmin(similarities).item())

            selected.append(index_unselected)
            # It's enough just to add to the selected vectors because dot product is distributive
            # It's divided by n for numerical stability
            selected_sum += doc_embeddings[index_unselected] / n

        ranked_docs: List[Document] = [documents[i] for i in selected]

        return ranked_docs

    def predict_batch(
        self,
        queries: List[str],
        documents: Union[List[Document], List[List[Document]]],
        top_k: Optional[int] = None,
        batch_size: Optional[int] = None,
    ) -> Union[List[Document], List[List[Document]]]:
        """
        Rank the documents based on their diversity and return the top_k documents.

        :param queries: The queries.
        :param documents: A list (or a list of lists) of Document objects that should be ranked.
        :param top_k: The maximum number of documents to return (per list of documents).
        :param batch_size: Accepted as part of the signature but not used by this implementation.

        :return: A list (or a list of lists) of top_k documents ranked based on diversity.
        :raises ValueError: If queries or documents are None or empty, or if the number of queries does not
            match the shape of `documents`.
        """
        if not queries:
            raise ValueError("No queries to choose from")
        if not documents:
            raise ValueError("No documents to choose from")

        if isinstance(documents[0], Document):
            # Docs case 1: single list of Documents -> rerank single list of Documents based on single query
            if len(queries) != 1:
                raise ValueError("Number of queries must be 1 if a single list of Documents is provided.")
            return self.predict(query=queries[0], documents=documents, top_k=top_k)  # type: ignore

        # Docs case 2: list of lists of Documents -> rerank each list of Documents based on corresponding query
        # If queries contains a single query, apply it to each list of Documents
        if len(queries) == 1:
            queries = queries * len(documents)
        if len(queries) != len(documents):
            raise ValueError("Number of queries must be equal to number of provided Document lists.")

        return [
            self.predict(query=query, documents=cur_docs, top_k=top_k)  # type: ignore
            for query, cur_docs in zip(queries, documents)
        ]
|
||||
@ -0,0 +1,17 @@
|
||||
---
|
||||
prelude: >
|
||||
We're introducing a new ranker to Haystack - DiversityRanker. This
|
||||
ranker aims to maximize the overall diversity of the given documents.
|
||||
It leverages sentence-transformer models to calculate semantic embeddings
|
||||
for each document. It orders documents so that the next one, on average,
|
||||
is least similar to the already selected documents. Such ranking results in a
|
||||
list where each subsequent document contributes the most to the overall
|
||||
diversity of the selected document set.
|
||||
features:
|
||||
- |
|
||||
The DiversityRanker can be used like other rankers in Haystack and
|
||||
it can be particularly helpful in cases where you have highly relevant
|
||||
yet similar sets of documents. By ensuring a diversity of documents,
|
||||
this new ranker facilitates a more comprehensive utilization of the
|
||||
documents and, particularly in RAG pipelines, potentially contributes
|
||||
to more accurate and rich model responses.
|
||||
233
test/nodes/test_diversity_ranker.py
Normal file
233
test/nodes/test_diversity_ranker.py
Normal file
@ -0,0 +1,233 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack import Document
|
||||
from haystack.nodes.ranker.diversity import DiversityRanker
|
||||
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_returns_list_of_documents(similarity: str):
    """predict() hands back a list of the same length that holds only Document objects."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    docs = [Document(content=text) for text in ("doc1", "doc2")]
    ranked = ranker.predict(query="test query", documents=docs)
    assert isinstance(ranked, list)
    assert len(ranked) == 2
    for item in ranked:
        assert isinstance(item, Document)
|
||||
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_returns_correct_number_of_documents(similarity: str):
    """predict() truncates its output to top_k documents."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    docs = [Document(content=text) for text in ("doc1", "doc2")]
    ranked = ranker.predict(query="test query", documents=docs, top_k=1)
    assert len(ranked) == 1
|
||||
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_returns_documents_in_correct_order(similarity: str):
    """predict() yields the expected diversity ordering for a small, hand-picked corpus."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    contents = ["France", "Germany", "Eiffel Tower", "Berlin", "bananas", "Silicon Valley", "Brandenburg Gate"]
    docs = [Document(text) for text in contents]
    ranked = ranker.predict(query="city", documents=docs)
    ranked_contents = ", ".join(doc.content for doc in ranked)
    assert ranked_contents == "Berlin, bananas, Eiffel Tower, Silicon Valley, France, Brandenburg Gate, Germany"
|
||||
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_returns_list_of_lists_of_documents(similarity: str):
    """predict_batch() with one document list per query returns a list of lists of Documents."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    batches = [
        [Document(content="doc1"), Document(content="doc2")],
        [Document(content="doc3"), Document(content="doc4")],
    ]
    ranked: List[List[Document]] = ranker.predict_batch(queries=["test query 1", "test query 2"], documents=batches)
    assert isinstance(ranked, list)
    for batch in ranked:
        assert isinstance(batch, list)
    for batch in ranked:
        for doc in batch:
            assert isinstance(doc, Document)
|
||||
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_returns_correct_number_of_documents(similarity: str):
    """predict_batch() applies top_k to every per-query result list."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    batches = [
        [Document(content="doc1"), Document(content="doc2")],
        [Document(content="doc3"), Document(content="doc4")],
    ]
    ranked: List[List[Document]] = ranker.predict_batch(
        queries=["test query 1", "test query 2"], documents=batches, top_k=1
    )
    assert len(ranked) == 2
    first, second = ranked[0], ranked[1]
    assert len(first) == 1
    assert len(second) == 1
|
||||
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_returns_documents_in_correct_order(similarity: str):
    """predict_batch() produces the expected diversity ordering for each query/document-list pair."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    batches = [
        [Document(content="Germany"), Document(content="Munich"), Document(content="agriculture")],
        [Document(content="France"), Document(content="Space exploration"), Document(content="Eiffel Tower")],
    ]
    ranked: List[List[Document]] = ranker.predict_batch(queries=["Berlin", "Paris"], documents=batches)
    assert len(ranked) == 2

    # One expected ordering per batch, in batch order
    expected_orders = ["Germany, agriculture, Munich", "France, Space exploration, Eiffel Tower"]
    for batch_index, expected in enumerate(expected_orders):
        assert ", ".join(doc.content for doc in ranked[batch_index]) == expected
|
||||
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_single_document_corner_case(similarity: str):
    """predict() copes with a single input document and returns exactly one result."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    only_doc = Document(content="doc1")
    ranked = ranker.predict(query="test", documents=[only_doc])
    assert len(ranked) == 1
|
||||
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_raises_value_error_if_query_is_empty(similarity: str):
    """predict() rejects an empty query string with ValueError."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    docs = [Document(content=text) for text in ("doc1", "doc2")]
    with pytest.raises(ValueError):
        ranker.predict(query="", documents=docs)
|
||||
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_raises_value_error_if_documents_is_empty(similarity: str):
    """predict() rejects an empty document list with ValueError."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    no_docs: List[Document] = []
    with pytest.raises(ValueError):
        ranker.predict(query="test query", documents=no_docs)
|
||||
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_raises_value_error_if_queries_is_empty(similarity: str):
    """predict_batch() rejects an empty list of queries with ValueError."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    no_queries: List[str] = []
    batches = [
        [Document(content="doc1"), Document(content="doc2")],
        [Document(content="doc3"), Document(content="doc4")],
    ]
    with pytest.raises(ValueError):
        ranker.predict_batch(queries=no_queries, documents=batches)
|
||||
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_raises_value_error_if_documents_is_empty(similarity: str):
    """predict_batch() rejects an empty documents argument with ValueError."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    no_docs: List[Document] = []
    with pytest.raises(ValueError):
        ranker.predict_batch(queries=["test query 1", "test query 2"], documents=no_docs)
|
||||
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_real_world_use_case(similarity: str):
    """predict() reproduces a known diversity ordering for eight passages on Polish-Russian relations."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    query = "What are the reasons for long-standing animosities between Russia and Poland?"

    # Passage texts in input order; index i corresponds to "doc{i + 1}".
    passages = [
        (
            "One of the earliest known events in Russian-Polish history dates back to 981, when the Grand Prince of Kiev , "
            "Vladimir Svyatoslavich , seized the Cherven Cities from the Duchy of Poland . The relationship between two by "
            "that time was mostly close and cordial, as there had been no serious wars between both. In 966, Poland "
            "accepted Christianity from Rome while Kievan Rus' —the ancestor of Russia, Ukraine and Belarus—was "
            "Christianized by Constantinople. In 1054, the internal Christian divide formally split the Church into "
            "the Catholic and Orthodox branches separating the Poles from the Eastern Slavs."
        ),
        (
            "Since the fall of the Soviet Union , with Lithuania , Ukraine and Belarus regaining independence, the "
            "Polish–Russian border has mostly been replaced by borders with the respective countries, but there still "
            "is a 210 km long border between Poland and the Kaliningrad Oblast"
        ),
        (
            "As part of Poland's plans to become fully energy independent from Russia within the next years, Piotr "
            "Wozniak, president of state-controlled oil and gas company PGNiG , stated in February 2019: 'The strategy of "
            "the company is just to forget about Eastern suppliers and especially about Gazprom .'[53] In 2020, the "
            "Stockholm Arbitral Tribunal ruled that PGNiG's long-term contract gas price with Gazprom linked to oil prices "
            "should be changed to approximate the Western European gas market price, backdated to 1 November 2014 when "
            "PGNiG requested a price review under the contract. Gazprom had to refund about $1.5 billion to PGNiG."
        ),
        (
            "Both Poland and Russia had accused each other for their historical revisionism . Russia has repeatedly "
            "accused Poland for not honoring Soviet Red Army soldiers fallen in World War II for Poland, notably in "
            "2017, in which Poland was thought on 'attempting to impose its own version of history' after Moscow was "
            "not allowed to join an international effort to renovate a World War II museum at Sobibór , site of a "
            "notorious Sobibor extermination camp."
        ),
        (
            "President of Russia Vladimir Putin and Prime Minister of Poland Leszek Miller in 2002 Modern Polish–Russian "
            "relations begin with the fall of communism – 1989 in Poland ( Solidarity and the Polish Round Table "
            "Agreement ) and 1991 in Russia ( dissolution of the Soviet Union ). With a new democratic government after "
            "the 1989 elections , Poland regained full sovereignty, [2] and what was the Soviet Union, became 15 newly "
            "independent states , including the Russian Federation . Relations between modern Poland and Russia suffer "
            "from constant ups and downs."
        ),
        (
            "Soviet influence in Poland finally ended with the Round Table Agreement of 1989 guaranteeing free elections "
            "in Poland, the Revolutions of 1989 against Soviet-sponsored Communist governments in the Eastern Bloc , and "
            "finally the formal dissolution of the Warsaw Pact."
        ),
        (
            "Dmitry Medvedev and then Polish Prime Minister Donald Tusk , 6 December 2010 BBC News reported that one of "
            "the main effects of the 2010 Polish Air Force Tu-154 crash would be the impact it has on Russian-Polish "
            "relations. [38] It was thought if the inquiry into the crash were not transparent, it would increase "
            "suspicions toward Russia in Poland."
        ),
        (
            "Soviet control over the Polish People's Republic lessened after Stalin's death and Gomułka's Thaw , and "
            "ceased completely after the fall of the communist government in Poland in late 1989, although the "
            "Soviet-Russian Northern Group of Forces did not leave Polish soil until 1993. The continuing Soviet military "
            "presence allowed the Soviet Union to heavily influence Polish politics."
        ),
    ]
    documents = [Document(text) for text in passages]

    result = ranker.predict(query=query, documents=documents)

    # Expected ranking doc5, doc7, doc3, doc1, doc4, doc2, doc6, doc8 expressed as zero-based input positions
    expected_order = [documents[position] for position in (4, 6, 2, 0, 3, 1, 5, 7)]
    assert result == expected_order
|
||||
Loading…
x
Reference in New Issue
Block a user