# haystack/test/nodes/test_diversity_ranker.py
#
# Repository viewer metadata: 268 lines, 13 KiB, Python.

from typing import List
import pytest
from haystack import Document
from haystack.nodes.ranker.diversity import DiversityRanker
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_returns_list_of_documents(similarity: str):
    """Check that ``predict`` hands back a plain list containing only ``Document`` objects."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    candidates = [Document(content="doc1"), Document(content="doc2")]

    ranked = ranker.predict(query="test query", documents=candidates)

    assert isinstance(ranked, list)
    assert len(ranked) == 2
    for item in ranked:
        assert isinstance(item, Document)
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_returns_correct_number_of_documents(similarity: str):
    """Check that passing ``top_k=1`` to ``predict`` yields exactly one of the two input documents."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    candidates = [Document(content="doc1"), Document(content="doc2")]

    ranked = ranker.predict(query="test query", documents=candidates, top_k=1)

    assert len(ranked) == 1
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_returns_documents_in_correct_order(similarity: str):
    """Check the exact ordering ``predict`` produces for seven short documents and the query "city"."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    candidates = [
        Document("France"),
        Document("Germany"),
        Document("Eiffel Tower"),
        Document("Berlin"),
        Document("bananas"),
        Document("Silicon Valley"),
        Document("Brandenburg Gate"),
    ]

    ranked = ranker.predict(query="city", documents=candidates)

    # The ordering is compared as one comma-separated string of document contents.
    ranked_contents = ", ".join(doc.content for doc in ranked)
    assert ranked_contents == "Berlin, bananas, Eiffel Tower, Silicon Valley, France, Brandenburg Gate, Germany"
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_returns_list_of_lists_of_documents(similarity: str):
    """Check that ``predict_batch`` returns a list whose items are lists of ``Document`` objects."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    batch_queries = ["test query 1", "test query 2"]
    batch_documents = [
        [Document(content="doc1"), Document(content="doc2")],
        [Document(content="doc3"), Document(content="doc4")],
    ]

    ranked: List[List[Document]] = ranker.predict_batch(queries=batch_queries, documents=batch_documents)

    assert isinstance(ranked, list)
    # Every per-query entry must be a list before its elements are inspected.
    for group in ranked:
        assert isinstance(group, list)
    for group in ranked:
        for item in group:
            assert isinstance(item, Document)
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_returns_correct_number_of_documents(similarity: str):
    """Check that ``predict_batch`` with ``top_k=1`` keeps one document for each of the two queries."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    batch_queries = ["test query 1", "test query 2"]
    batch_documents = [
        [Document(content="doc1"), Document(content="doc2")],
        [Document(content="doc3"), Document(content="doc4")],
    ]

    ranked: List[List[Document]] = ranker.predict_batch(
        queries=batch_queries, documents=batch_documents, top_k=1
    )

    # Two result groups (one per query), each trimmed to a single document.
    assert [len(group) for group in ranked] == [1, 1]
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_returns_documents_in_correct_order(similarity: str):
    """Check the exact per-query ordering ``predict_batch`` produces for two small batches."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    batch_queries = ["Berlin", "Paris"]
    batch_documents = [
        [Document(content="Germany"), Document(content="Munich"), Document(content="agriculture")],
        [Document(content="France"), Document(content="Space exploration"), Document(content="Eiffel Tower")],
    ]
    # Expected ordering for each query, written as comma-separated document contents.
    expected_per_query = [
        "Germany, agriculture, Munich",
        "France, Space exploration, Eiffel Tower",
    ]

    ranked: List[List[Document]] = ranker.predict_batch(queries=batch_queries, documents=batch_documents)

    assert len(ranked) == 2
    for group, expected in zip(ranked, expected_per_query):
        assert ", ".join(doc.content for doc in group) == expected
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_single_document_corner_case(similarity: str):
    """Check that ``predict`` given a single document returns a single document."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    only_document = [Document(content="doc1")]

    ranked = ranker.predict(query="test", documents=only_document)

    assert len(ranked) == 1
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_raises_value_error_if_query_is_empty(similarity: str):
    """Check that ``predict`` raises ``ValueError`` when the query is an empty string."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    candidates = [Document(content="doc1"), Document(content="doc2")]

    with pytest.raises(ValueError):
        ranker.predict(query="", documents=candidates)
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_raises_value_error_if_documents_is_empty(similarity: str):
    """Check that ``predict`` raises ``ValueError`` when the document list is empty."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore

    with pytest.raises(ValueError):
        ranker.predict(query="test query", documents=[])
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_raises_value_error_if_queries_is_empty(similarity: str):
    """Check that ``predict_batch`` raises ``ValueError`` when the query list is empty."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    batch_documents = [
        [Document(content="doc1"), Document(content="doc2")],
        [Document(content="doc3"), Document(content="doc4")],
    ]

    with pytest.raises(ValueError):
        ranker.predict_batch(queries=[], documents=batch_documents)
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_raises_value_error_if_documents_is_empty(similarity: str):
    """Check that ``predict_batch`` raises ``ValueError`` when the documents list is empty."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    batch_queries = ["test query 1", "test query 2"]

    with pytest.raises(ValueError):
        ranker.predict_batch(queries=batch_queries, documents=[])
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_real_world_use_case(similarity: str):
    """Check the exact ordering ``predict`` produces for eight longer passages on Russian-Polish relations."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    query = "What are the reasons for long-standing animosities between Russia and Poland?"

    # Passage texts in input order; each parenthesised group is one document.
    passages = [
        (
            "One of the earliest known events in Russian-Polish history dates back to 981, when the Grand Prince of Kiev , "
            "Vladimir Svyatoslavich , seized the Cherven Cities from the Duchy of Poland . The relationship between two by "
            "that time was mostly close and cordial, as there had been no serious wars between both. In 966, Poland "
            "accepted Christianity from Rome while Kievan Rus' —the ancestor of Russia, Ukraine and Belarus—was "
            "Christianized by Constantinople. In 1054, the internal Christian divide formally split the Church into "
            "the Catholic and Orthodox branches separating the Poles from the Eastern Slavs."
        ),
        (
            "Since the fall of the Soviet Union , with Lithuania , Ukraine and Belarus regaining independence, the "
            "PolishRussian border has mostly been replaced by borders with the respective countries, but there still "
            "is a 210 km long border between Poland and the Kaliningrad Oblast"
        ),
        (
            "As part of Poland's plans to become fully energy independent from Russia within the next years, Piotr "
            "Wozniak, president of state-controlled oil and gas company PGNiG , stated in February 2019: 'The strategy of "
            "the company is just to forget about Eastern suppliers and especially about Gazprom .'[53] In 2020, the "
            "Stockholm Arbitral Tribunal ruled that PGNiG's long-term contract gas price with Gazprom linked to oil prices "
            "should be changed to approximate the Western European gas market price, backdated to 1 November 2014 when "
            "PGNiG requested a price review under the contract. Gazprom had to refund about $1.5 billion to PGNiG."
        ),
        (
            "Both Poland and Russia had accused each other for their historical revisionism . Russia has repeatedly "
            "accused Poland for not honoring Soviet Red Army soldiers fallen in World War II for Poland, notably in "
            "2017, in which Poland was thought on 'attempting to impose its own version of history' after Moscow was "
            "not allowed to join an international effort to renovate a World War II museum at Sobibór , site of a "
            "notorious Sobibor extermination camp."
        ),
        (
            "President of Russia Vladimir Putin and Prime Minister of Poland Leszek Miller in 2002 Modern PolishRussian "
            "relations begin with the fall of communism 1989 in Poland ( Solidarity and the Polish Round Table "
            "Agreement ) and 1991 in Russia ( dissolution of the Soviet Union ). With a new democratic government after "
            "the 1989 elections , Poland regained full sovereignty, [2] and what was the Soviet Union, became 15 newly "
            "independent states , including the Russian Federation . Relations between modern Poland and Russia suffer "
            "from constant ups and downs."
        ),
        (
            "Soviet influence in Poland finally ended with the Round Table Agreement of 1989 guaranteeing free elections "
            "in Poland, the Revolutions of 1989 against Soviet-sponsored Communist governments in the Eastern Bloc , and "
            "finally the formal dissolution of the Warsaw Pact."
        ),
        (
            "Dmitry Medvedev and then Polish Prime Minister Donald Tusk , 6 December 2010 BBC News reported that one of "
            "the main effects of the 2010 Polish Air Force Tu-154 crash would be the impact it has on Russian-Polish "
            "relations. [38] It was thought if the inquiry into the crash were not transparent, it would increase "
            "suspicions toward Russia in Poland."
        ),
        (
            "Soviet control over the Polish People's Republic lessened after Stalin's death and Gomułka's Thaw , and "
            "ceased completely after the fall of the communist government in Poland in late 1989, although the "
            "Soviet-Russian Northern Group of Forces did not leave Polish soil until 1993. The continuing Soviet military "
            "presence allowed the Soviet Union to heavily influence Polish politics."
        ),
    ]
    documents = [Document(text) for text in passages]

    ranked = ranker.predict(query=query, documents=documents)

    # Expected ranking as zero-based positions in `documents`
    # (i.e. passages 5, 7, 3, 1, 4, 2, 6, 8 when counted from one).
    expected_positions = (4, 6, 2, 0, 3, 1, 5, 7)
    assert ranked == [documents[position] for position in expected_positions]
@pytest.mark.integration
def test_diversity_ranker_with_top_k():
    """Check that ``top_k=1`` given to the constructor limits ``predict`` to one of three documents."""
    ranker = DiversityRanker(similarity="cosine", top_k=1)
    candidates = [Document(content=name) for name in ("doc1", "doc2", "doc3")]

    ranked = ranker.predict(query="test", documents=candidates)

    assert len(ranked) == 1
@pytest.mark.integration
def test_diversity_ranker_with_top_k_edge():
    """Check how many of three documents ``predict`` returns for unusual constructor ``top_k`` values."""
    # Each entry pairs a constructor top_k with the number of documents expected back.
    scenarios = (
        (5, 3),  # top_k larger than the number of documents: all three come back
        (-5, 0),  # negative top_k should return empty list
        (None, 3),  # original note: None is ignored in slice notation, so make sure it works
    )
    for top_k, expected_count in scenarios:
        # A fresh ranker and fresh documents are built for every scenario.
        ranker = DiversityRanker(similarity="cosine", top_k=top_k)
        candidates = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")]

        ranked = ranker.predict(query="test", documents=candidates)

        assert len(ranked) == expected_count