| 
									
										
										
										
											2023-10-06 16:01:34 +02:00
										 |  |  | import pytest | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from haystack.preview import Document, ComponentError | 
					
						
							| 
									
										
										
										
											2023-10-24 19:45:16 +02:00
										 |  |  | from haystack.preview.components.rankers.transformers_similarity import TransformersSimilarityRanker | 
					
						
							| 
									
										
										
										
											2023-10-06 16:01:34 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class TestSimilarityRanker: | 
					
						
							|  |  |  |     @pytest.mark.unit | 
					
						
							|  |  |  |     def test_to_dict(self): | 
					
						
							| 
									
										
										
										
											2023-10-24 19:45:16 +02:00
										 |  |  |         component = TransformersSimilarityRanker() | 
					
						
							| 
									
										
										
										
											2023-10-06 16:01:34 +02:00
										 |  |  |         data = component.to_dict() | 
					
						
							|  |  |  |         assert data == { | 
					
						
							| 
									
										
										
										
											2023-10-24 19:45:16 +02:00
										 |  |  |             "type": "TransformersSimilarityRanker", | 
					
						
							| 
									
										
										
										
											2023-10-12 13:52:01 +02:00
										 |  |  |             "init_parameters": { | 
					
						
							|  |  |  |                 "device": "cpu", | 
					
						
							|  |  |  |                 "top_k": 10, | 
					
						
							| 
									
										
										
										
											2023-10-17 16:32:13 +02:00
										 |  |  |                 "token": None, | 
					
						
							| 
									
										
										
										
											2023-10-24 19:45:16 +02:00
										 |  |  |                 "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2", | 
					
						
							| 
									
										
										
										
											2023-10-12 13:52:01 +02:00
										 |  |  |             }, | 
					
						
							| 
									
										
										
										
											2023-10-06 16:01:34 +02:00
										 |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.unit | 
					
						
							|  |  |  |     def test_to_dict_with_custom_init_parameters(self): | 
					
						
							| 
									
										
										
										
											2023-10-24 19:45:16 +02:00
										 |  |  |         component = TransformersSimilarityRanker( | 
					
						
							|  |  |  |             model_name_or_path="my_model", device="cuda", token="my_token", top_k=5 | 
					
						
							|  |  |  |         ) | 
					
						
							| 
									
										
										
										
											2023-10-06 16:01:34 +02:00
										 |  |  |         data = component.to_dict() | 
					
						
							|  |  |  |         assert data == { | 
					
						
							| 
									
										
										
										
											2023-10-24 19:45:16 +02:00
										 |  |  |             "type": "TransformersSimilarityRanker", | 
					
						
							| 
									
										
										
										
											2023-10-12 13:52:01 +02:00
										 |  |  |             "init_parameters": { | 
					
						
							| 
									
										
										
										
											2023-10-17 16:32:13 +02:00
										 |  |  |                 "device": "cuda", | 
					
						
							|  |  |  |                 "model_name_or_path": "my_model", | 
					
						
							|  |  |  |                 "token": None,  # we don't serialize valid tokens, | 
					
						
							|  |  |  |                 "top_k": 5, | 
					
						
							| 
									
										
										
										
											2023-10-12 13:52:01 +02:00
										 |  |  |             }, | 
					
						
							| 
									
										
										
										
											2023-10-06 16:01:34 +02:00
										 |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     @pytest.mark.parametrize( | 
					
						
							|  |  |  |         "query,docs_before_texts,expected_first_text", | 
					
						
							|  |  |  |         [ | 
					
						
							|  |  |  |             ("City in Bosnia and Herzegovina", ["Berlin", "Belgrade", "Sarajevo"], "Sarajevo"), | 
					
						
							|  |  |  |             ("Machine learning", ["Python", "Bakery in Paris", "Tesla Giga Berlin"], "Python"), | 
					
						
							|  |  |  |             ("Cubist movement", ["Nirvana", "Pablo Picasso", "Coffee"], "Pablo Picasso"), | 
					
						
							|  |  |  |         ], | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     def test_run(self, query, docs_before_texts, expected_first_text): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Test if the component ranks documents correctly. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2023-10-24 19:45:16 +02:00
										 |  |  |         ranker = TransformersSimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2") | 
					
						
							| 
									
										
										
										
											2023-10-06 16:01:34 +02:00
										 |  |  |         ranker.warm_up() | 
					
						
							| 
									
										
										
										
											2023-10-31 12:44:04 +01:00
										 |  |  |         docs_before = [Document(content=text) for text in docs_before_texts] | 
					
						
							| 
									
										
										
										
											2023-10-06 16:01:34 +02:00
										 |  |  |         output = ranker.run(query=query, documents=docs_before) | 
					
						
							|  |  |  |         docs_after = output["documents"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         assert len(docs_after) == 3 | 
					
						
							| 
									
										
										
										
											2023-10-31 12:44:04 +01:00
										 |  |  |         assert docs_after[0].content == expected_first_text | 
					
						
							| 
									
										
										
										
											2023-10-06 16:01:34 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |         sorted_scores = sorted([doc.score for doc in docs_after], reverse=True) | 
					
						
							|  |  |  |         assert [doc.score for doc in docs_after] == sorted_scores | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     #  Returns an empty list if no documents are provided | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_returns_empty_list_if_no_documents_are_provided(self): | 
					
						
							| 
									
										
										
										
											2023-10-24 19:45:16 +02:00
										 |  |  |         sampler = TransformersSimilarityRanker() | 
					
						
							| 
									
										
										
										
											2023-10-06 16:01:34 +02:00
										 |  |  |         sampler.warm_up() | 
					
						
							|  |  |  |         output = sampler.run(query="City in Germany", documents=[]) | 
					
						
							| 
									
										
										
										
											2023-10-31 12:44:04 +01:00
										 |  |  |         assert not output["documents"] | 
					
						
							| 
									
										
										
										
											2023-10-06 16:01:34 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     #  Raises ComponentError if model is not warmed up | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     def test_raises_component_error_if_model_not_warmed_up(self): | 
					
						
							| 
									
										
										
										
											2023-10-24 19:45:16 +02:00
										 |  |  |         sampler = TransformersSimilarityRanker() | 
					
						
							| 
									
										
										
										
											2023-10-06 16:01:34 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |         with pytest.raises(ComponentError): | 
					
						
							| 
									
										
										
										
											2023-10-31 12:44:04 +01:00
										 |  |  |             sampler.run(query="query", documents=[Document(content="document")]) | 
					
						
							| 
									
										
										
										
											2023-10-12 13:52:01 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @pytest.mark.integration | 
					
						
							|  |  |  |     @pytest.mark.parametrize( | 
					
						
							|  |  |  |         "query,docs_before_texts,expected_first_text", | 
					
						
							|  |  |  |         [ | 
					
						
							|  |  |  |             ("City in Bosnia and Herzegovina", ["Berlin", "Belgrade", "Sarajevo"], "Sarajevo"), | 
					
						
							|  |  |  |             ("Machine learning", ["Python", "Bakery in Paris", "Tesla Giga Berlin"], "Python"), | 
					
						
							|  |  |  |             ("Cubist movement", ["Nirvana", "Pablo Picasso", "Coffee"], "Pablo Picasso"), | 
					
						
							|  |  |  |         ], | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     def test_run_top_k(self, query, docs_before_texts, expected_first_text): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Test if the component ranks documents correctly with a custom top_k. | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2023-10-24 19:45:16 +02:00
										 |  |  |         ranker = TransformersSimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2) | 
					
						
							| 
									
										
										
										
											2023-10-12 13:52:01 +02:00
										 |  |  |         ranker.warm_up() | 
					
						
							| 
									
										
										
										
											2023-10-31 12:44:04 +01:00
										 |  |  |         docs_before = [Document(content=text) for text in docs_before_texts] | 
					
						
							| 
									
										
										
										
											2023-10-12 13:52:01 +02:00
										 |  |  |         output = ranker.run(query=query, documents=docs_before) | 
					
						
							|  |  |  |         docs_after = output["documents"] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         assert len(docs_after) == 2 | 
					
						
							| 
									
										
										
										
											2023-10-31 12:44:04 +01:00
										 |  |  |         assert docs_after[0].content == expected_first_text | 
					
						
							| 
									
										
										
										
											2023-10-12 13:52:01 +02:00
										 |  |  | 
 | 
					
						
							|  |  |  |         sorted_scores = sorted([doc.score for doc in docs_after], reverse=True) | 
					
						
							|  |  |  |         assert [doc.score for doc in docs_after] == sorted_scores |