mirror of
				https://github.com/deepset-ai/haystack.git
				synced 2025-11-04 11:49:23 +00:00 
			
		
		
		
	refactor!: rename SimilarityRanker to TransformersSimilarityRanker (#6100)
				
					
				
			* rename * release note * Update haystack/preview/components/rankers/transformers_similarity.py Co-authored-by: Domenico <domenico.cinque98@gmail.com> * Update haystack/preview/components/rankers/transformers_similarity.py Co-authored-by: Domenico <domenico.cinque98@gmail.com> * fix test --------- Co-authored-by: Domenico <domenico.cinque98@gmail.com>
This commit is contained in:
		
							parent
							
								
									1cf70d3dce
								
							
						
					
					
						commit
						1f4ed3cc03
					
				@ -1,3 +1,3 @@
 | 
				
			|||||||
from haystack.preview.components.rankers.similarity import SimilarityRanker
 | 
					from haystack.preview.components.rankers.transformers_similarity import TransformersSimilarityRanker
 | 
				
			||||||
 | 
					
 | 
				
			||||||
__all__ = ["SimilarityRanker"]
 | 
					__all__ = ["TransformersSimilarityRanker"]
 | 
				
			||||||
 | 
				
			|||||||
@ -14,19 +14,20 @@ with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]==4.3
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@component
 | 
					@component
 | 
				
			||||||
class SimilarityRanker:
 | 
					class TransformersSimilarityRanker:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    Ranks documents based on query similarity.
 | 
					    Ranks documents based on query similarity.
 | 
				
			||||||
 | 
					    It uses a pre-trained cross-encoder model (from Hugging Face Hub) to embed the query and documents.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    Usage example:
 | 
					    Usage example:
 | 
				
			||||||
    ```
 | 
					    ```
 | 
				
			||||||
    from haystack.preview import Document
 | 
					    from haystack.preview import Document
 | 
				
			||||||
    from haystack.preview.components.rankers import SimilarityRanker
 | 
					    from haystack.preview.components.rankers import TransformersSimilarityRanker
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    sampler = SimilarityRanker()
 | 
					    ranker = TransformersSimilarityRanker()
 | 
				
			||||||
    docs = [Document(text="Paris"), Document(text="Berlin")]
 | 
					    docs = [Document(text="Paris"), Document(text="Berlin")]
 | 
				
			||||||
    query = "City in Germany"
 | 
					    query = "City in Germany"
 | 
				
			||||||
    output = sampler.run(query=query, documents=docs)
 | 
					    output = ranker.run(query=query, documents=docs)
 | 
				
			||||||
    docs = output["documents"]
 | 
					    docs = output["documents"]
 | 
				
			||||||
    assert len(docs) == 2
 | 
					    assert len(docs) == 2
 | 
				
			||||||
    assert docs[0].text == "Berlin"
 | 
					    assert docs[0].text == "Berlin"
 | 
				
			||||||
@ -41,9 +42,10 @@ class SimilarityRanker:
 | 
				
			|||||||
        top_k: int = 10,
 | 
					        top_k: int = 10,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Creates an instance of SimilarityRanker.
 | 
					        Creates an instance of TransformersSimilarityRanker.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        :param model_name_or_path: Path to a pre-trained sentence-transformers model.
 | 
					        :param model_name_or_path: The name or path of a pre-trained cross-encoder model
 | 
				
			||||||
 | 
					            from Hugging Face Hub.
 | 
				
			||||||
        :param device: torch device (for example, cuda:0, cpu, mps) to limit model inference to a specific device.
 | 
					        :param device: torch device (for example, cuda:0, cpu, mps) to limit model inference to a specific device.
 | 
				
			||||||
        :param token: The API token used to download private models from Hugging Face.
 | 
					        :param token: The API token used to download private models from Hugging Face.
 | 
				
			||||||
            If this parameter is set to `True`, then the token generated when running
 | 
					            If this parameter is set to `True`, then the token generated when running
 | 
				
			||||||
@ -0,0 +1,5 @@
 | 
				
			|||||||
 | 
					---
 | 
				
			||||||
 | 
					preview:
 | 
				
			||||||
 | 
					  - |
 | 
				
			||||||
 | 
					    Rename `SimilarityRanker` to `TransformersSimilarityRanker`,
 | 
				
			||||||
 | 
					    as there will be more similarity rankers in the future.
 | 
				
			||||||
@ -1,30 +1,32 @@
 | 
				
			|||||||
import pytest
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from haystack.preview import Document, ComponentError
 | 
					from haystack.preview import Document, ComponentError
 | 
				
			||||||
from haystack.preview.components.rankers.similarity import SimilarityRanker
 | 
					from haystack.preview.components.rankers.transformers_similarity import TransformersSimilarityRanker
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestSimilarityRanker:
 | 
					class TestSimilarityRanker:
 | 
				
			||||||
    @pytest.mark.unit
 | 
					    @pytest.mark.unit
 | 
				
			||||||
    def test_to_dict(self):
 | 
					    def test_to_dict(self):
 | 
				
			||||||
        component = SimilarityRanker()
 | 
					        component = TransformersSimilarityRanker()
 | 
				
			||||||
        data = component.to_dict()
 | 
					        data = component.to_dict()
 | 
				
			||||||
        assert data == {
 | 
					        assert data == {
 | 
				
			||||||
            "type": "SimilarityRanker",
 | 
					            "type": "TransformersSimilarityRanker",
 | 
				
			||||||
            "init_parameters": {
 | 
					            "init_parameters": {
 | 
				
			||||||
                "device": "cpu",
 | 
					                "device": "cpu",
 | 
				
			||||||
                "top_k": 10,
 | 
					                "top_k": 10,
 | 
				
			||||||
                "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
 | 
					 | 
				
			||||||
                "token": None,
 | 
					                "token": None,
 | 
				
			||||||
 | 
					                "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
 | 
				
			||||||
            },
 | 
					            },
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @pytest.mark.unit
 | 
					    @pytest.mark.unit
 | 
				
			||||||
    def test_to_dict_with_custom_init_parameters(self):
 | 
					    def test_to_dict_with_custom_init_parameters(self):
 | 
				
			||||||
        component = SimilarityRanker(model_name_or_path="my_model", device="cuda", token="my_token", top_k=5)
 | 
					        component = TransformersSimilarityRanker(
 | 
				
			||||||
 | 
					            model_name_or_path="my_model", device="cuda", token="my_token", top_k=5
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
        data = component.to_dict()
 | 
					        data = component.to_dict()
 | 
				
			||||||
        assert data == {
 | 
					        assert data == {
 | 
				
			||||||
            "type": "SimilarityRanker",
 | 
					            "type": "TransformersSimilarityRanker",
 | 
				
			||||||
            "init_parameters": {
 | 
					            "init_parameters": {
 | 
				
			||||||
                "device": "cuda",
 | 
					                "device": "cuda",
 | 
				
			||||||
                "model_name_or_path": "my_model",
 | 
					                "model_name_or_path": "my_model",
 | 
				
			||||||
@ -46,7 +48,7 @@ class TestSimilarityRanker:
 | 
				
			|||||||
        """
 | 
					        """
 | 
				
			||||||
        Test if the component ranks documents correctly.
 | 
					        Test if the component ranks documents correctly.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        ranker = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2")
 | 
					        ranker = TransformersSimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2")
 | 
				
			||||||
        ranker.warm_up()
 | 
					        ranker.warm_up()
 | 
				
			||||||
        docs_before = [Document(text=text) for text in docs_before_texts]
 | 
					        docs_before = [Document(text=text) for text in docs_before_texts]
 | 
				
			||||||
        output = ranker.run(query=query, documents=docs_before)
 | 
					        output = ranker.run(query=query, documents=docs_before)
 | 
				
			||||||
@ -61,7 +63,7 @@ class TestSimilarityRanker:
 | 
				
			|||||||
    #  Returns an empty list if no documents are provided
 | 
					    #  Returns an empty list if no documents are provided
 | 
				
			||||||
    @pytest.mark.integration
 | 
					    @pytest.mark.integration
 | 
				
			||||||
    def test_returns_empty_list_if_no_documents_are_provided(self):
 | 
					    def test_returns_empty_list_if_no_documents_are_provided(self):
 | 
				
			||||||
        sampler = SimilarityRanker()
 | 
					        sampler = TransformersSimilarityRanker()
 | 
				
			||||||
        sampler.warm_up()
 | 
					        sampler.warm_up()
 | 
				
			||||||
        output = sampler.run(query="City in Germany", documents=[])
 | 
					        output = sampler.run(query="City in Germany", documents=[])
 | 
				
			||||||
        assert output["documents"] == []
 | 
					        assert output["documents"] == []
 | 
				
			||||||
@ -69,7 +71,7 @@ class TestSimilarityRanker:
 | 
				
			|||||||
    #  Raises ComponentError if model is not warmed up
 | 
					    #  Raises ComponentError if model is not warmed up
 | 
				
			||||||
    @pytest.mark.integration
 | 
					    @pytest.mark.integration
 | 
				
			||||||
    def test_raises_component_error_if_model_not_warmed_up(self):
 | 
					    def test_raises_component_error_if_model_not_warmed_up(self):
 | 
				
			||||||
        sampler = SimilarityRanker()
 | 
					        sampler = TransformersSimilarityRanker()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        with pytest.raises(ComponentError):
 | 
					        with pytest.raises(ComponentError):
 | 
				
			||||||
            sampler.run(query="query", documents=[Document(text="document")])
 | 
					            sampler.run(query="query", documents=[Document(text="document")])
 | 
				
			||||||
@ -87,7 +89,7 @@ class TestSimilarityRanker:
 | 
				
			|||||||
        """
 | 
					        """
 | 
				
			||||||
        Test if the component ranks documents correctly with a custom top_k.
 | 
					        Test if the component ranks documents correctly with a custom top_k.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        ranker = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2)
 | 
					        ranker = TransformersSimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2)
 | 
				
			||||||
        ranker.warm_up()
 | 
					        ranker.warm_up()
 | 
				
			||||||
        docs_before = [Document(text=text) for text in docs_before_texts]
 | 
					        docs_before = [Document(text=text) for text in docs_before_texts]
 | 
				
			||||||
        output = ranker.run(query=query, documents=docs_before)
 | 
					        output = ranker.run(query=query, documents=docs_before)
 | 
				
			||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user