| 
									
										
										
										
											2023-08-01 12:48:34 +02:00
										 |  |  |  | from typing import List | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | import pytest | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | from haystack import Document | 
					
						
							|  |  |  |  | from haystack.nodes.ranker.diversity import DiversityRanker | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | # Tests that predict method returns a list of Document objects | 
					
						
							|  |  |  |  | @pytest.mark.integration | 
					
						
							|  |  |  |  | @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) | 
					
						
							|  |  |  |  | def test_predict_returns_list_of_documents(similarity: str): | 
					
						
							|  |  |  |  |     ranker = DiversityRanker(similarity=similarity)  # type: ignore | 
					
						
							|  |  |  |  |     query = "test query" | 
					
						
							|  |  |  |  |     documents = [Document(content="doc1"), Document(content="doc2")] | 
					
						
							|  |  |  |  |     result = ranker.predict(query=query, documents=documents) | 
					
						
							|  |  |  |  |     assert isinstance(result, list) | 
					
						
							|  |  |  |  |     assert len(result) == 2 | 
					
						
							|  |  |  |  |     assert all(isinstance(doc, Document) for doc in result) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | #  Tests that predict method returns the correct number of documents | 
					
						
							|  |  |  |  | @pytest.mark.integration | 
					
						
							|  |  |  |  | @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) | 
					
						
							|  |  |  |  | def test_predict_returns_correct_number_of_documents(similarity: str): | 
					
						
							|  |  |  |  |     ranker = DiversityRanker(similarity=similarity)  # type: ignore | 
					
						
							|  |  |  |  |     query = "test query" | 
					
						
							|  |  |  |  |     documents = [Document(content="doc1"), Document(content="doc2")] | 
					
						
							|  |  |  |  |     result = ranker.predict(query=query, documents=documents, top_k=1) | 
					
						
							|  |  |  |  |     assert len(result) == 1 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | #  Tests that predict method returns documents in the correct order | 
					
						
							|  |  |  |  | @pytest.mark.integration | 
					
						
							|  |  |  |  | @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) | 
					
						
							|  |  |  |  | def test_predict_returns_documents_in_correct_order(similarity: str): | 
					
						
							|  |  |  |  |     ranker = DiversityRanker(similarity=similarity)  # type: ignore | 
					
						
							|  |  |  |  |     query = "city" | 
					
						
							|  |  |  |  |     documents = [ | 
					
						
							|  |  |  |  |         Document("France"), | 
					
						
							|  |  |  |  |         Document("Germany"), | 
					
						
							|  |  |  |  |         Document("Eiffel Tower"), | 
					
						
							|  |  |  |  |         Document("Berlin"), | 
					
						
							|  |  |  |  |         Document("bananas"), | 
					
						
							|  |  |  |  |         Document("Silicon Valley"), | 
					
						
							|  |  |  |  |         Document("Brandenburg Gate"), | 
					
						
							|  |  |  |  |     ] | 
					
						
							|  |  |  |  |     result = ranker.predict(query=query, documents=documents) | 
					
						
							|  |  |  |  |     expected_order = "Berlin, bananas, Eiffel Tower, Silicon Valley, France, Brandenburg Gate, Germany" | 
					
						
							|  |  |  |  |     assert ", ".join([doc.content for doc in result]) == expected_order | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | #  Tests that predict_batch method returns a list of lists of Document objects | 
					
						
							|  |  |  |  | @pytest.mark.integration | 
					
						
							|  |  |  |  | @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) | 
					
						
							|  |  |  |  | def test_predict_batch_returns_list_of_lists_of_documents(similarity: str): | 
					
						
							|  |  |  |  |     ranker = DiversityRanker(similarity=similarity)  # type: ignore | 
					
						
							|  |  |  |  |     queries = ["test query 1", "test query 2"] | 
					
						
							|  |  |  |  |     documents = [ | 
					
						
							|  |  |  |  |         [Document(content="doc1"), Document(content="doc2")], | 
					
						
							|  |  |  |  |         [Document(content="doc3"), Document(content="doc4")], | 
					
						
							|  |  |  |  |     ] | 
					
						
							|  |  |  |  |     result: List[List[Document]] = ranker.predict_batch(queries=queries, documents=documents) | 
					
						
							|  |  |  |  |     assert isinstance(result, list) | 
					
						
							|  |  |  |  |     assert all(isinstance(docs, list) for docs in result) | 
					
						
							|  |  |  |  |     assert all(isinstance(doc, Document) for docs in result for doc in docs) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | #  Tests that predict_batch method returns the correct number of documents | 
					
						
							|  |  |  |  | @pytest.mark.integration | 
					
						
							|  |  |  |  | @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) | 
					
						
							|  |  |  |  | def test_predict_batch_returns_correct_number_of_documents(similarity: str): | 
					
						
							|  |  |  |  |     ranker = DiversityRanker(similarity=similarity)  # type: ignore | 
					
						
							|  |  |  |  |     queries = ["test query 1", "test query 2"] | 
					
						
							|  |  |  |  |     documents = [ | 
					
						
							|  |  |  |  |         [Document(content="doc1"), Document(content="doc2")], | 
					
						
							|  |  |  |  |         [Document(content="doc3"), Document(content="doc4")], | 
					
						
							|  |  |  |  |     ] | 
					
						
							|  |  |  |  |     result: List[List[Document]] = ranker.predict_batch(queries=queries, documents=documents, top_k=1) | 
					
						
							|  |  |  |  |     assert len(result) == 2 | 
					
						
							|  |  |  |  |     assert len(result[0]) == 1 | 
					
						
							|  |  |  |  |     assert len(result[1]) == 1 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | #  Tests that predict_batch method returns documents in the correct order | 
					
						
							|  |  |  |  | @pytest.mark.integration | 
					
						
							|  |  |  |  | @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) | 
					
						
							|  |  |  |  | def test_predict_batch_returns_documents_in_correct_order(similarity: str): | 
					
						
							|  |  |  |  |     ranker = DiversityRanker(similarity=similarity)  # type: ignore | 
					
						
							|  |  |  |  |     queries = ["Berlin", "Paris"] | 
					
						
							|  |  |  |  |     documents = [ | 
					
						
							|  |  |  |  |         [Document(content="Germany"), Document(content="Munich"), Document(content="agriculture")], | 
					
						
							|  |  |  |  |         [Document(content="France"), Document(content="Space exploration"), Document(content="Eiffel Tower")], | 
					
						
							|  |  |  |  |     ] | 
					
						
							|  |  |  |  |     result: List[List[Document]] = ranker.predict_batch(queries=queries, documents=documents) | 
					
						
							|  |  |  |  |     assert len(result) == 2 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # check the correct most diverse order are in batches | 
					
						
							|  |  |  |  |     expected_order_0 = "Germany, agriculture, Munich" | 
					
						
							|  |  |  |  |     expected_order_1 = "France, Space exploration, Eiffel Tower" | 
					
						
							|  |  |  |  |     assert ", ".join([doc.content for doc in result[0]]) == expected_order_0 | 
					
						
							|  |  |  |  |     assert ", ".join([doc.content for doc in result[1]]) == expected_order_1 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | # Tests that predict method returns the correct number of documents for a single document | 
					
						
							|  |  |  |  | @pytest.mark.integration | 
					
						
							|  |  |  |  | @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) | 
					
						
							|  |  |  |  | def test_predict_single_document_corner_case(similarity: str): | 
					
						
							|  |  |  |  |     ranker = DiversityRanker(similarity=similarity)  # type: ignore | 
					
						
							|  |  |  |  |     query = "test" | 
					
						
							|  |  |  |  |     documents = [Document(content="doc1")] | 
					
						
							|  |  |  |  |     result = ranker.predict(query=query, documents=documents) | 
					
						
							|  |  |  |  |     assert len(result) == 1 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | #  Tests that predict method raises ValueError if query is empty | 
					
						
							|  |  |  |  | @pytest.mark.integration | 
					
						
							|  |  |  |  | @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) | 
					
						
							|  |  |  |  | def test_predict_raises_value_error_if_query_is_empty(similarity: str): | 
					
						
							|  |  |  |  |     ranker = DiversityRanker(similarity=similarity)  # type: ignore | 
					
						
							|  |  |  |  |     query = "" | 
					
						
							|  |  |  |  |     documents = [Document(content="doc1"), Document(content="doc2")] | 
					
						
							|  |  |  |  |     with pytest.raises(ValueError): | 
					
						
							|  |  |  |  |         ranker.predict(query=query, documents=documents) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | #  Tests that predict method raises ValueError if documents is empty | 
					
						
							|  |  |  |  | @pytest.mark.integration | 
					
						
							|  |  |  |  | @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) | 
					
						
							|  |  |  |  | def test_predict_raises_value_error_if_documents_is_empty(similarity: str): | 
					
						
							|  |  |  |  |     ranker = DiversityRanker(similarity=similarity)  # type: ignore | 
					
						
							|  |  |  |  |     query = "test query" | 
					
						
							|  |  |  |  |     documents = [] | 
					
						
							|  |  |  |  |     with pytest.raises(ValueError): | 
					
						
							|  |  |  |  |         ranker.predict(query=query, documents=documents) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | #  Tests that predict_batch method raises ValueError if queries is empty | 
					
						
							|  |  |  |  | @pytest.mark.integration | 
					
						
							|  |  |  |  | @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) | 
					
						
							|  |  |  |  | def test_predict_batch_raises_value_error_if_queries_is_empty(similarity: str): | 
					
						
							|  |  |  |  |     ranker = DiversityRanker(similarity=similarity)  # type: ignore | 
					
						
							|  |  |  |  |     queries = [] | 
					
						
							|  |  |  |  |     documents = [ | 
					
						
							|  |  |  |  |         [Document(content="doc1"), Document(content="doc2")], | 
					
						
							|  |  |  |  |         [Document(content="doc3"), Document(content="doc4")], | 
					
						
							|  |  |  |  |     ] | 
					
						
							|  |  |  |  |     with pytest.raises(ValueError): | 
					
						
							|  |  |  |  |         ranker.predict_batch(queries=queries, documents=documents) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | #  Tests that predict_batch method raises ValueError if documents is empty | 
					
						
							|  |  |  |  | @pytest.mark.integration | 
					
						
							|  |  |  |  | @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) | 
					
						
							|  |  |  |  | def test_predict_batch_raises_value_error_if_documents_is_empty(similarity: str): | 
					
						
							|  |  |  |  |     ranker = DiversityRanker(similarity=similarity)  # type: ignore | 
					
						
							|  |  |  |  |     queries = ["test query 1", "test query 2"] | 
					
						
							|  |  |  |  |     documents = [] | 
					
						
							|  |  |  |  |     with pytest.raises(ValueError): | 
					
						
							|  |  |  |  |         ranker.predict_batch(queries=queries, documents=documents) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | @pytest.mark.integration | 
					
						
							|  |  |  |  | @pytest.mark.parametrize("similarity", ["dot_product", "cosine"]) | 
					
						
							|  |  |  |  | def test_predict_real_world_use_case(similarity: str): | 
					
						
							|  |  |  |  |     ranker = DiversityRanker(similarity=similarity)  # type: ignore | 
					
						
							|  |  |  |  |     query = "What are the reasons for long-standing animosities between Russia and Poland?" | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     doc1 = Document( | 
					
						
							|  |  |  |  |         "One of the earliest known events in Russian-Polish history dates back to 981, when the Grand Prince of Kiev , " | 
					
						
							|  |  |  |  |         "Vladimir Svyatoslavich , seized the Cherven Cities from the Duchy of Poland . The relationship between two by " | 
					
						
							|  |  |  |  |         "that time was mostly close and cordial, as there had been no serious wars between both. In 966, Poland " | 
					
						
							|  |  |  |  |         "accepted Christianity from Rome while Kievan Rus' —the ancestor of Russia, Ukraine and Belarus—was " | 
					
						
							|  |  |  |  |         "Christianized by Constantinople. In 1054, the internal Christian divide formally split the Church into " | 
					
						
							|  |  |  |  |         "the Catholic and Orthodox branches separating the Poles from the Eastern Slavs." | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     doc2 = Document( | 
					
						
							|  |  |  |  |         "Since the fall of the Soviet Union , with Lithuania , Ukraine and Belarus regaining independence, the " | 
					
						
							|  |  |  |  |         "Polish–Russian border has mostly been replaced by borders with the respective countries, but there still " | 
					
						
							|  |  |  |  |         "is a 210 km long border between Poland and the Kaliningrad Oblast" | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     doc3 = Document( | 
					
						
							|  |  |  |  |         "As part of Poland's plans to become fully energy independent from Russia within the next years, Piotr " | 
					
						
							|  |  |  |  |         "Wozniak, president of state-controlled oil and gas company PGNiG , stated in February 2019: 'The strategy of " | 
					
						
							|  |  |  |  |         "the company is just to forget about Eastern suppliers and especially about Gazprom .'[53] In 2020, the " | 
					
						
							|  |  |  |  |         "Stockholm Arbitral Tribunal ruled that PGNiG's long-term contract gas price with Gazprom linked to oil prices " | 
					
						
							|  |  |  |  |         "should be changed to approximate the Western European gas market price, backdated to 1 November 2014 when " | 
					
						
							|  |  |  |  |         "PGNiG requested a price review under the contract. Gazprom had to refund about $1.5 billion to PGNiG." | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     doc4 = Document( | 
					
						
							|  |  |  |  |         "Both Poland and Russia had accused each other for their historical revisionism . Russia has repeatedly " | 
					
						
							|  |  |  |  |         "accused Poland for not honoring Soviet Red Army soldiers fallen in World War II for Poland, notably in " | 
					
						
							|  |  |  |  |         "2017, in which Poland was thought on 'attempting to impose its own version of history' after Moscow was " | 
					
						
							|  |  |  |  |         "not allowed to join an international effort to renovate a World War II museum at Sobibór , site of a " | 
					
						
							|  |  |  |  |         "notorious Sobibor extermination camp." | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     doc5 = Document( | 
					
						
							|  |  |  |  |         "President of Russia Vladimir Putin and Prime Minister of Poland Leszek Miller in 2002 Modern Polish–Russian " | 
					
						
							|  |  |  |  |         "relations begin with the fall of communism – 1989 in Poland ( Solidarity and the Polish Round Table " | 
					
						
							|  |  |  |  |         "Agreement ) and 1991 in Russia ( dissolution of the Soviet Union ). With a new democratic government after " | 
					
						
							|  |  |  |  |         "the 1989 elections , Poland regained full sovereignty, [2] and what was the Soviet Union, became 15 newly " | 
					
						
							|  |  |  |  |         "independent states , including the Russian Federation . Relations between modern Poland and Russia suffer " | 
					
						
							|  |  |  |  |         "from constant ups and downs." | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     doc6 = Document( | 
					
						
							|  |  |  |  |         "Soviet influence in Poland finally ended with the Round Table Agreement of 1989 guaranteeing free elections " | 
					
						
							|  |  |  |  |         "in Poland, the Revolutions of 1989 against Soviet-sponsored Communist governments in the Eastern Bloc , and " | 
					
						
							|  |  |  |  |         "finally the formal dissolution of the Warsaw Pact." | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     doc7 = Document( | 
					
						
							|  |  |  |  |         "Dmitry Medvedev and then Polish Prime Minister Donald Tusk , 6 December 2010 BBC News reported that one of " | 
					
						
							|  |  |  |  |         "the main effects of the 2010 Polish Air Force Tu-154 crash would be the impact it has on Russian-Polish " | 
					
						
							|  |  |  |  |         "relations. [38] It was thought if the inquiry into the crash were not transparent, it would increase " | 
					
						
							|  |  |  |  |         "suspicions toward Russia in Poland." | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     doc8 = Document( | 
					
						
							|  |  |  |  |         "Soviet control over the Polish People's Republic lessened after Stalin's death and Gomułka's Thaw , and " | 
					
						
							|  |  |  |  |         "ceased completely after the fall of the communist government in Poland in late 1989, although the " | 
					
						
							|  |  |  |  |         "Soviet-Russian Northern Group of Forces did not leave Polish soil until 1993. The continuing Soviet military " | 
					
						
							|  |  |  |  |         "presence allowed the Soviet Union to heavily influence Polish politics." | 
					
						
							|  |  |  |  |     ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     documents = [doc1, doc2, doc3, doc4, doc5, doc6, doc7, doc8] | 
					
						
							|  |  |  |  |     result = ranker.predict(query=query, documents=documents) | 
					
						
							|  |  |  |  |     expected_order = [doc5, doc7, doc3, doc1, doc4, doc2, doc6, doc8] | 
					
						
							|  |  |  |  |     assert result == expected_order | 
					
						
							| 
									
										
										
										
											2023-08-02 17:04:04 +02:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | @pytest.mark.integration | 
					
						
							|  |  |  |  | def test_diversity_ranker_with_top_k(): | 
					
						
							|  |  |  |  |     # Tests that predict method returns the correct order of documents | 
					
						
							|  |  |  |  |     ranker = DiversityRanker(similarity="cosine", top_k=1) | 
					
						
							|  |  |  |  |     query = "test" | 
					
						
							|  |  |  |  |     documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] | 
					
						
							|  |  |  |  |     result = ranker.predict(query=query, documents=documents) | 
					
						
							|  |  |  |  |     assert len(result) == 1 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | @pytest.mark.integration | 
					
						
							|  |  |  |  | def test_diversity_ranker_with_top_k_edge(): | 
					
						
							|  |  |  |  |     # Tests that predict method returns the correct order of documents for edge cases | 
					
						
							|  |  |  |  |     ranker = DiversityRanker(similarity="cosine", top_k=5) | 
					
						
							|  |  |  |  |     query = "test" | 
					
						
							|  |  |  |  |     documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] | 
					
						
							|  |  |  |  |     result = ranker.predict(query=query, documents=documents) | 
					
						
							|  |  |  |  |     assert len(result) == 3 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # negative top_k should return empty list | 
					
						
							|  |  |  |  |     ranker = DiversityRanker(similarity="cosine", top_k=-5) | 
					
						
							|  |  |  |  |     query = "test" | 
					
						
							|  |  |  |  |     documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] | 
					
						
							|  |  |  |  |     result = ranker.predict(query=query, documents=documents) | 
					
						
							|  |  |  |  |     assert len(result) == 0 | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  |     # we know None is ignored in slice notation, but let's make sure it works | 
					
						
							|  |  |  |  |     ranker = DiversityRanker(similarity="cosine", top_k=None) | 
					
						
							|  |  |  |  |     query = "test" | 
					
						
							|  |  |  |  |     documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] | 
					
						
							|  |  |  |  |     result = ranker.predict(query=query, documents=documents) | 
					
						
							|  |  |  |  |     assert len(result) == 3 |