mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-31 09:49:48 +00:00 
			
		
		
		
	refactor!: rename SimilarityRanker to TransformersSimilarityRanker (#6100)
				
					
				
			* rename * release note * Update haystack/preview/components/rankers/transformers_similarity.py Co-authored-by: Domenico <domenico.cinque98@gmail.com> * Update haystack/preview/components/rankers/transformers_similarity.py Co-authored-by: Domenico <domenico.cinque98@gmail.com> * fix test --------- Co-authored-by: Domenico <domenico.cinque98@gmail.com>
This commit is contained in:
		
							parent
							
								
									1cf70d3dce
								
							
						
					
					
						commit
						1f4ed3cc03
					
				| @ -1,3 +1,3 @@ | ||||
| from haystack.preview.components.rankers.similarity import SimilarityRanker | ||||
| from haystack.preview.components.rankers.transformers_similarity import TransformersSimilarityRanker | ||||
| 
 | ||||
| __all__ = ["SimilarityRanker"] | ||||
| __all__ = ["TransformersSimilarityRanker"] | ||||
|  | ||||
| @ -14,19 +14,20 @@ with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]==4.3 | ||||
| 
 | ||||
| 
 | ||||
| @component | ||||
| class SimilarityRanker: | ||||
| class TransformersSimilarityRanker: | ||||
|     """ | ||||
|     Ranks documents based on query similarity. | ||||
|     It uses a pre-trained cross-encoder model (from Hugging Face Hub) to embed the query and documents. | ||||
| 
 | ||||
|     Usage example: | ||||
|     ``` | ||||
|     from haystack.preview import Document | ||||
|     from haystack.preview.components.rankers import SimilarityRanker | ||||
|     from haystack.preview.components.rankers import TransformersSimilarityRanker | ||||
| 
 | ||||
|     sampler = SimilarityRanker() | ||||
|     ranker = TransformersSimilarityRanker() | ||||
|     docs = [Document(text="Paris"), Document(text="Berlin")] | ||||
|     query = "City in Germany" | ||||
|     output = sampler.run(query=query, documents=docs) | ||||
|     output = ranker.run(query=query, documents=docs) | ||||
|     docs = output["documents"] | ||||
|     assert len(docs) == 2 | ||||
|     assert docs[0].text == "Berlin" | ||||
| @ -41,9 +42,10 @@ class SimilarityRanker: | ||||
|         top_k: int = 10, | ||||
|     ): | ||||
|         """ | ||||
|         Creates an instance of SimilarityRanker. | ||||
|         Creates an instance of TransformersSimilarityRanker. | ||||
| 
 | ||||
|         :param model_name_or_path: Path to a pre-trained sentence-transformers model. | ||||
|         :param model_name_or_path: The name or path of a pre-trained cross-encoder model | ||||
|             from Hugging Face Hub. | ||||
|         :param device: torch device (for example, cuda:0, cpu, mps) to limit model inference to a specific device. | ||||
|         :param token: The API token used to download private models from Hugging Face. | ||||
|             If this parameter is set to `True`, then the token generated when running | ||||
| @ -0,0 +1,5 @@ | ||||
| --- | ||||
| preview: | ||||
|   - | | ||||
|     Rename `SimilarityRanker` to `TransformersSimilarityRanker`, | ||||
|     as there will be more similarity rankers in the future. | ||||
| @ -1,30 +1,32 @@ | ||||
| import pytest | ||||
| 
 | ||||
| from haystack.preview import Document, ComponentError | ||||
| from haystack.preview.components.rankers.similarity import SimilarityRanker | ||||
| from haystack.preview.components.rankers.transformers_similarity import TransformersSimilarityRanker | ||||
| 
 | ||||
| 
 | ||||
| class TestSimilarityRanker: | ||||
|     @pytest.mark.unit | ||||
|     def test_to_dict(self): | ||||
|         component = SimilarityRanker() | ||||
|         component = TransformersSimilarityRanker() | ||||
|         data = component.to_dict() | ||||
|         assert data == { | ||||
|             "type": "SimilarityRanker", | ||||
|             "type": "TransformersSimilarityRanker", | ||||
|             "init_parameters": { | ||||
|                 "device": "cpu", | ||||
|                 "top_k": 10, | ||||
|                 "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2", | ||||
|                 "token": None, | ||||
|                 "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2", | ||||
|             }, | ||||
|         } | ||||
| 
 | ||||
|     @pytest.mark.unit | ||||
|     def test_to_dict_with_custom_init_parameters(self): | ||||
|         component = SimilarityRanker(model_name_or_path="my_model", device="cuda", token="my_token", top_k=5) | ||||
|         component = TransformersSimilarityRanker( | ||||
|             model_name_or_path="my_model", device="cuda", token="my_token", top_k=5 | ||||
|         ) | ||||
|         data = component.to_dict() | ||||
|         assert data == { | ||||
|             "type": "SimilarityRanker", | ||||
|             "type": "TransformersSimilarityRanker", | ||||
|             "init_parameters": { | ||||
|                 "device": "cuda", | ||||
|                 "model_name_or_path": "my_model", | ||||
| @ -46,7 +48,7 @@ class TestSimilarityRanker: | ||||
|         """ | ||||
|         Test if the component ranks documents correctly. | ||||
|         """ | ||||
|         ranker = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2") | ||||
|         ranker = TransformersSimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2") | ||||
|         ranker.warm_up() | ||||
|         docs_before = [Document(text=text) for text in docs_before_texts] | ||||
|         output = ranker.run(query=query, documents=docs_before) | ||||
| @ -61,7 +63,7 @@ class TestSimilarityRanker: | ||||
|     #  Returns an empty list if no documents are provided | ||||
|     @pytest.mark.integration | ||||
|     def test_returns_empty_list_if_no_documents_are_provided(self): | ||||
|         sampler = SimilarityRanker() | ||||
|         sampler = TransformersSimilarityRanker() | ||||
|         sampler.warm_up() | ||||
|         output = sampler.run(query="City in Germany", documents=[]) | ||||
|         assert output["documents"] == [] | ||||
| @ -69,7 +71,7 @@ class TestSimilarityRanker: | ||||
|     #  Raises ComponentError if model is not warmed up | ||||
|     @pytest.mark.integration | ||||
|     def test_raises_component_error_if_model_not_warmed_up(self): | ||||
|         sampler = SimilarityRanker() | ||||
|         sampler = TransformersSimilarityRanker() | ||||
| 
 | ||||
|         with pytest.raises(ComponentError): | ||||
|             sampler.run(query="query", documents=[Document(text="document")]) | ||||
| @ -87,7 +89,7 @@ class TestSimilarityRanker: | ||||
|         """ | ||||
|         Test if the component ranks documents correctly with a custom top_k. | ||||
|         """ | ||||
|         ranker = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2) | ||||
|         ranker = TransformersSimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2) | ||||
|         ranker.warm_up() | ||||
|         docs_before = [Document(text=text) for text in docs_before_texts] | ||||
|         output = ranker.run(query=query, documents=docs_before) | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Stefano Fiorucci
						Stefano Fiorucci