haystack/test/nodes/test_diversity_ranker.py
Vladimir Blagojevic 0efe0ee7b3
feat: Add top_k parameter to DiversityRanker init method (#5494)
* Add top_k

* Add release note
2023-08-02 17:04:04 +02:00

268 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import List
import pytest
from haystack import Document
from haystack.nodes.ranker.diversity import DiversityRanker
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_returns_list_of_documents(similarity: str):
    """predict() hands back a plain list that holds only Document instances."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    candidates = [Document(content="doc1"), Document(content="doc2")]

    ranked = ranker.predict(query="test query", documents=candidates)

    assert isinstance(ranked, list)
    # Without a top_k argument both input documents are expected back.
    assert len(ranked) == 2
    assert all(isinstance(item, Document) for item in ranked)
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_returns_correct_number_of_documents(similarity: str):
    """Passing top_k=1 to predict() trims a two-document input to one result."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    candidates = [Document(content="doc1"), Document(content="doc2")]

    ranked = ranker.predict(query="test query", documents=candidates, top_k=1)

    assert len(ranked) == 1
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_returns_documents_in_correct_order(similarity: str):
    """predict() reorders seven short documents into the pinned sequence for the query "city"."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    contents = (
        "France",
        "Germany",
        "Eiffel Tower",
        "Berlin",
        "bananas",
        "Silicon Valley",
        "Brandenburg Gate",
    )
    candidates = [Document(text) for text in contents]

    ranked = ranker.predict(query="city", documents=candidates)

    # Compare as one comma-separated string so a failure shows the whole ordering at once.
    ranked_contents = ", ".join([doc.content for doc in ranked])
    assert ranked_contents == "Berlin, bananas, Eiffel Tower, Silicon Valley, France, Brandenburg Gate, Germany"
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_returns_list_of_lists_of_documents(similarity: str):
    """predict_batch() yields a list of lists in which every element is a Document."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    first_batch = [Document(content="doc1"), Document(content="doc2")]
    second_batch = [Document(content="doc3"), Document(content="doc4")]

    batches: List[List[Document]] = ranker.predict_batch(
        queries=["test query 1", "test query 2"], documents=[first_batch, second_batch]
    )

    assert isinstance(batches, list)
    # Check the container types first, then the items inside each container.
    assert all(isinstance(batch, list) for batch in batches)
    assert all(isinstance(item, Document) for batch in batches for item in batch)
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_returns_correct_number_of_documents(similarity: str):
    """With top_k=1, predict_batch() returns one document for each of the two queries."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    first_batch = [Document(content="doc1"), Document(content="doc2")]
    second_batch = [Document(content="doc3"), Document(content="doc4")]

    batches: List[List[Document]] = ranker.predict_batch(
        queries=["test query 1", "test query 2"], documents=[first_batch, second_batch], top_k=1
    )

    # One result list per query, each cut down to a single document.
    assert len(batches) == 2
    assert [len(batch) for batch in batches] == [1, 1]
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_returns_documents_in_correct_order(similarity: str):
    """predict_batch() orders each batch of documents into its own pinned sequence."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    berlin_docs = [Document(content="Germany"), Document(content="Munich"), Document(content="agriculture")]
    paris_docs = [Document(content="France"), Document(content="Space exploration"), Document(content="Eiffel Tower")]

    batches: List[List[Document]] = ranker.predict_batch(
        queries=["Berlin", "Paris"], documents=[berlin_docs, paris_docs]
    )

    assert len(batches) == 2
    # Each batch must come back in its own expected order.
    expected_orders = ["Germany, agriculture, Munich", "France, Space exploration, Eiffel Tower"]
    for batch, expected_order in zip(batches, expected_orders):
        assert ", ".join([doc.content for doc in batch]) == expected_order
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_single_document_corner_case(similarity: str):
    """A one-document input goes through predict() and comes back as one result."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    only_document = Document(content="doc1")

    ranked = ranker.predict(query="test", documents=[only_document])

    assert len(ranked) == 1
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_raises_value_error_if_query_is_empty(similarity: str):
    """predict() is expected to raise ValueError when the query is an empty string."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    # Build the inputs outside the raises block so only predict() can satisfy it.
    candidates = [Document(content="doc1"), Document(content="doc2")]

    with pytest.raises(ValueError):
        ranker.predict(query="", documents=candidates)
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_raises_value_error_if_documents_is_empty(similarity: str):
    """predict() is expected to raise ValueError when given an empty document list."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore

    with pytest.raises(ValueError):
        ranker.predict(query="test query", documents=[])
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_raises_value_error_if_queries_is_empty(similarity: str):
    """predict_batch() is expected to raise ValueError when the list of queries is empty."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore
    # Build the inputs outside the raises block so only predict_batch() can satisfy it.
    batches = [
        [Document(content="doc1"), Document(content="doc2")],
        [Document(content="doc3"), Document(content="doc4")],
    ]

    with pytest.raises(ValueError):
        ranker.predict_batch(queries=[], documents=batches)
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_batch_raises_value_error_if_documents_is_empty(similarity: str):
    """predict_batch() is expected to raise ValueError when the documents argument is empty."""
    ranker = DiversityRanker(similarity=similarity)  # type: ignore

    with pytest.raises(ValueError):
        ranker.predict_batch(queries=["test query 1", "test query 2"], documents=[])
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_predict_real_world_use_case(similarity: str):
    """Rank eight longer passages about Russian-Polish relations and compare against a pinned order.

    The passage texts are kept verbatim, odd spacing and punctuation included,
    because the expected order was recorded for exactly these strings.
    """
    ranker = DiversityRanker(similarity=similarity)  # type: ignore

    doc1 = Document(
        "One of the earliest known events in Russian-Polish history dates back to 981, when the Grand Prince of Kiev , "
        "Vladimir Svyatoslavich , seized the Cherven Cities from the Duchy of Poland . The relationship between two by "
        "that time was mostly close and cordial, as there had been no serious wars between both. In 966, Poland "
        "accepted Christianity from Rome while Kievan Rus' —the ancestor of Russia, Ukraine and Belarus—was "
        "Christianized by Constantinople. In 1054, the internal Christian divide formally split the Church into "
        "the Catholic and Orthodox branches separating the Poles from the Eastern Slavs."
    )
    doc2 = Document(
        "Since the fall of the Soviet Union , with Lithuania , Ukraine and Belarus regaining independence, the "
        "PolishRussian border has mostly been replaced by borders with the respective countries, but there still "
        "is a 210 km long border between Poland and the Kaliningrad Oblast"
    )
    doc3 = Document(
        "As part of Poland's plans to become fully energy independent from Russia within the next years, Piotr "
        "Wozniak, president of state-controlled oil and gas company PGNiG , stated in February 2019: 'The strategy of "
        "the company is just to forget about Eastern suppliers and especially about Gazprom .'[53] In 2020, the "
        "Stockholm Arbitral Tribunal ruled that PGNiG's long-term contract gas price with Gazprom linked to oil prices "
        "should be changed to approximate the Western European gas market price, backdated to 1 November 2014 when "
        "PGNiG requested a price review under the contract. Gazprom had to refund about $1.5 billion to PGNiG."
    )
    doc4 = Document(
        "Both Poland and Russia had accused each other for their historical revisionism . Russia has repeatedly "
        "accused Poland for not honoring Soviet Red Army soldiers fallen in World War II for Poland, notably in "
        "2017, in which Poland was thought on 'attempting to impose its own version of history' after Moscow was "
        "not allowed to join an international effort to renovate a World War II museum at Sobibór , site of a "
        "notorious Sobibor extermination camp."
    )
    doc5 = Document(
        "President of Russia Vladimir Putin and Prime Minister of Poland Leszek Miller in 2002 Modern PolishRussian "
        "relations begin with the fall of communism 1989 in Poland ( Solidarity and the Polish Round Table "
        "Agreement ) and 1991 in Russia ( dissolution of the Soviet Union ). With a new democratic government after "
        "the 1989 elections , Poland regained full sovereignty, [2] and what was the Soviet Union, became 15 newly "
        "independent states , including the Russian Federation . Relations between modern Poland and Russia suffer "
        "from constant ups and downs."
    )
    doc6 = Document(
        "Soviet influence in Poland finally ended with the Round Table Agreement of 1989 guaranteeing free elections "
        "in Poland, the Revolutions of 1989 against Soviet-sponsored Communist governments in the Eastern Bloc , and "
        "finally the formal dissolution of the Warsaw Pact."
    )
    doc7 = Document(
        "Dmitry Medvedev and then Polish Prime Minister Donald Tusk , 6 December 2010 BBC News reported that one of "
        "the main effects of the 2010 Polish Air Force Tu-154 crash would be the impact it has on Russian-Polish "
        "relations. [38] It was thought if the inquiry into the crash were not transparent, it would increase "
        "suspicions toward Russia in Poland."
    )
    doc8 = Document(
        "Soviet control over the Polish People's Republic lessened after Stalin's death and Gomułka's Thaw , and "
        "ceased completely after the fall of the communist government in Poland in late 1989, although the "
        "Soviet-Russian Northern Group of Forces did not leave Polish soil until 1993. The continuing Soviet military "
        "presence allowed the Soviet Union to heavily influence Polish politics."
    )

    ranked = ranker.predict(
        query="What are the reasons for long-standing animosities between Russia and Poland?",
        documents=[doc1, doc2, doc3, doc4, doc5, doc6, doc7, doc8],
    )

    # Whole Document objects are compared, not only their contents.
    assert ranked == [doc5, doc7, doc3, doc1, doc4, doc2, doc6, doc8]
@pytest.mark.integration
def test_diversity_ranker_with_top_k():
    """A top_k of 1 given to the constructor limits predict() to one of three documents."""
    ranker = DiversityRanker(similarity="cosine", top_k=1)
    candidates = [Document(content=name) for name in ("doc1", "doc2", "doc3")]

    # No top_k is passed to predict(); the limit comes from the constructor.
    ranked = ranker.predict(query="test", documents=candidates)

    assert len(ranked) == 1
@pytest.mark.integration
def test_diversity_ranker_with_top_k_edge():
    """Check the result count for edge-case top_k values given to the constructor.

    Each scenario builds a fresh ranker and a fresh three-document input, then
    asserts on how many documents predict() returns.
    """
    # (top_k given to the constructor, number of documents predict() must return)
    scenarios = [
        (5, 3),  # top_k above the input size: all three documents are expected
        (-5, 0),  # negative top_k should return empty list
        (None, 3),  # we know None is ignored in slice notation, but let's make sure it works
    ]

    for top_k, expected_count in scenarios:
        ranker = DiversityRanker(similarity="cosine", top_k=top_k)
        candidates = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")]
        ranked = ranker.predict(query="test", documents=candidates)
        assert len(ranked) == expected_count