mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-10-31 09:49:48 +00:00 
			
		
		
		
	refactor!: rename SimilarityRanker to TransformersSimilarityRanker (#6100)
				
					
				
			* rename * release note * Update haystack/preview/components/rankers/transformers_similarity.py Co-authored-by: Domenico <domenico.cinque98@gmail.com> * Update haystack/preview/components/rankers/transformers_similarity.py Co-authored-by: Domenico <domenico.cinque98@gmail.com> * fix test --------- Co-authored-by: Domenico <domenico.cinque98@gmail.com>
This commit is contained in:
		
							parent
							
								
									1cf70d3dce
								
							
						
					
					
						commit
						1f4ed3cc03
					
				| @ -1,3 +1,3 @@ | |||||||
| from haystack.preview.components.rankers.similarity import SimilarityRanker | from haystack.preview.components.rankers.transformers_similarity import TransformersSimilarityRanker | ||||||
| 
 | 
 | ||||||
| __all__ = ["SimilarityRanker"] | __all__ = ["TransformersSimilarityRanker"] | ||||||
|  | |||||||
| @ -14,19 +14,20 @@ with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]==4.3 | |||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @component | @component | ||||||
| class SimilarityRanker: | class TransformersSimilarityRanker: | ||||||
|     """ |     """ | ||||||
|     Ranks documents based on query similarity. |     Ranks documents based on query similarity. | ||||||
|  |     It uses a pre-trained cross-encoder model (from Hugging Face Hub) to embed the query and documents. | ||||||
| 
 | 
 | ||||||
|     Usage example: |     Usage example: | ||||||
|     ``` |     ``` | ||||||
|     from haystack.preview import Document |     from haystack.preview import Document | ||||||
|     from haystack.preview.components.rankers import SimilarityRanker |     from haystack.preview.components.rankers import TransformersSimilarityRanker | ||||||
| 
 | 
 | ||||||
|     sampler = SimilarityRanker() |     ranker = TransformersSimilarityRanker() | ||||||
|     docs = [Document(text="Paris"), Document(text="Berlin")] |     docs = [Document(text="Paris"), Document(text="Berlin")] | ||||||
|     query = "City in Germany" |     query = "City in Germany" | ||||||
|     output = sampler.run(query=query, documents=docs) |     output = ranker.run(query=query, documents=docs) | ||||||
|     docs = output["documents"] |     docs = output["documents"] | ||||||
|     assert len(docs) == 2 |     assert len(docs) == 2 | ||||||
|     assert docs[0].text == "Berlin" |     assert docs[0].text == "Berlin" | ||||||
| @ -41,9 +42,10 @@ class SimilarityRanker: | |||||||
|         top_k: int = 10, |         top_k: int = 10, | ||||||
|     ): |     ): | ||||||
|         """ |         """ | ||||||
|         Creates an instance of SimilarityRanker. |         Creates an instance of TransformersSimilarityRanker. | ||||||
| 
 | 
 | ||||||
|         :param model_name_or_path: Path to a pre-trained sentence-transformers model. |         :param model_name_or_path: The name or path of a pre-trained cross-encoder model | ||||||
|  |             from Hugging Face Hub. | ||||||
|         :param device: torch device (for example, cuda:0, cpu, mps) to limit model inference to a specific device. |         :param device: torch device (for example, cuda:0, cpu, mps) to limit model inference to a specific device. | ||||||
|         :param token: The API token used to download private models from Hugging Face. |         :param token: The API token used to download private models from Hugging Face. | ||||||
|             If this parameter is set to `True`, then the token generated when running |             If this parameter is set to `True`, then the token generated when running | ||||||
| @ -0,0 +1,5 @@ | |||||||
|  | --- | ||||||
|  | preview: | ||||||
|  |   - | | ||||||
|  |     Rename `SimilarityRanker` to `TransformersSimilarityRanker`, | ||||||
|  |     as there will be more similarity rankers in the future. | ||||||
| @ -1,30 +1,32 @@ | |||||||
| import pytest | import pytest | ||||||
| 
 | 
 | ||||||
| from haystack.preview import Document, ComponentError | from haystack.preview import Document, ComponentError | ||||||
| from haystack.preview.components.rankers.similarity import SimilarityRanker | from haystack.preview.components.rankers.transformers_similarity import TransformersSimilarityRanker | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class TestSimilarityRanker: | class TestSimilarityRanker: | ||||||
|     @pytest.mark.unit |     @pytest.mark.unit | ||||||
|     def test_to_dict(self): |     def test_to_dict(self): | ||||||
|         component = SimilarityRanker() |         component = TransformersSimilarityRanker() | ||||||
|         data = component.to_dict() |         data = component.to_dict() | ||||||
|         assert data == { |         assert data == { | ||||||
|             "type": "SimilarityRanker", |             "type": "TransformersSimilarityRanker", | ||||||
|             "init_parameters": { |             "init_parameters": { | ||||||
|                 "device": "cpu", |                 "device": "cpu", | ||||||
|                 "top_k": 10, |                 "top_k": 10, | ||||||
|                 "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2", |  | ||||||
|                 "token": None, |                 "token": None, | ||||||
|  |                 "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2", | ||||||
|             }, |             }, | ||||||
|         } |         } | ||||||
| 
 | 
 | ||||||
|     @pytest.mark.unit |     @pytest.mark.unit | ||||||
|     def test_to_dict_with_custom_init_parameters(self): |     def test_to_dict_with_custom_init_parameters(self): | ||||||
|         component = SimilarityRanker(model_name_or_path="my_model", device="cuda", token="my_token", top_k=5) |         component = TransformersSimilarityRanker( | ||||||
|  |             model_name_or_path="my_model", device="cuda", token="my_token", top_k=5 | ||||||
|  |         ) | ||||||
|         data = component.to_dict() |         data = component.to_dict() | ||||||
|         assert data == { |         assert data == { | ||||||
|             "type": "SimilarityRanker", |             "type": "TransformersSimilarityRanker", | ||||||
|             "init_parameters": { |             "init_parameters": { | ||||||
|                 "device": "cuda", |                 "device": "cuda", | ||||||
|                 "model_name_or_path": "my_model", |                 "model_name_or_path": "my_model", | ||||||
| @ -46,7 +48,7 @@ class TestSimilarityRanker: | |||||||
|         """ |         """ | ||||||
|         Test if the component ranks documents correctly. |         Test if the component ranks documents correctly. | ||||||
|         """ |         """ | ||||||
|         ranker = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2") |         ranker = TransformersSimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2") | ||||||
|         ranker.warm_up() |         ranker.warm_up() | ||||||
|         docs_before = [Document(text=text) for text in docs_before_texts] |         docs_before = [Document(text=text) for text in docs_before_texts] | ||||||
|         output = ranker.run(query=query, documents=docs_before) |         output = ranker.run(query=query, documents=docs_before) | ||||||
| @ -61,7 +63,7 @@ class TestSimilarityRanker: | |||||||
|     #  Returns an empty list if no documents are provided |     #  Returns an empty list if no documents are provided | ||||||
|     @pytest.mark.integration |     @pytest.mark.integration | ||||||
|     def test_returns_empty_list_if_no_documents_are_provided(self): |     def test_returns_empty_list_if_no_documents_are_provided(self): | ||||||
|         sampler = SimilarityRanker() |         sampler = TransformersSimilarityRanker() | ||||||
|         sampler.warm_up() |         sampler.warm_up() | ||||||
|         output = sampler.run(query="City in Germany", documents=[]) |         output = sampler.run(query="City in Germany", documents=[]) | ||||||
|         assert output["documents"] == [] |         assert output["documents"] == [] | ||||||
| @ -69,7 +71,7 @@ class TestSimilarityRanker: | |||||||
|     #  Raises ComponentError if model is not warmed up |     #  Raises ComponentError if model is not warmed up | ||||||
|     @pytest.mark.integration |     @pytest.mark.integration | ||||||
|     def test_raises_component_error_if_model_not_warmed_up(self): |     def test_raises_component_error_if_model_not_warmed_up(self): | ||||||
|         sampler = SimilarityRanker() |         sampler = TransformersSimilarityRanker() | ||||||
| 
 | 
 | ||||||
|         with pytest.raises(ComponentError): |         with pytest.raises(ComponentError): | ||||||
|             sampler.run(query="query", documents=[Document(text="document")]) |             sampler.run(query="query", documents=[Document(text="document")]) | ||||||
| @ -87,7 +89,7 @@ class TestSimilarityRanker: | |||||||
|         """ |         """ | ||||||
|         Test if the component ranks documents correctly with a custom top_k. |         Test if the component ranks documents correctly with a custom top_k. | ||||||
|         """ |         """ | ||||||
|         ranker = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2) |         ranker = TransformersSimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2) | ||||||
|         ranker.warm_up() |         ranker.warm_up() | ||||||
|         docs_before = [Document(text=text) for text in docs_before_texts] |         docs_before = [Document(text=text) for text in docs_before_texts] | ||||||
|         output = ranker.run(query=query, documents=docs_before) |         output = ranker.run(query=query, documents=docs_before) | ||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Stefano Fiorucci
						Stefano Fiorucci