refactor!: rename SimilarityRanker to TransformersSimilarityRanker (#6100)

* rename

* release note

* Update haystack/preview/components/rankers/transformers_similarity.py

Co-authored-by: Domenico <domenico.cinque98@gmail.com>

* Update haystack/preview/components/rankers/transformers_similarity.py

Co-authored-by: Domenico <domenico.cinque98@gmail.com>

* fix test

---------

Co-authored-by: Domenico <domenico.cinque98@gmail.com>
This commit is contained in:
Stefano Fiorucci 2023-10-24 19:45:16 +02:00 committed by GitHub
parent 1cf70d3dce
commit 1f4ed3cc03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 27 additions and 18 deletions

View File

@ -1,3 +1,3 @@
from haystack.preview.components.rankers.similarity import SimilarityRanker from haystack.preview.components.rankers.transformers_similarity import TransformersSimilarityRanker
__all__ = ["SimilarityRanker"] __all__ = ["TransformersSimilarityRanker"]

View File

@ -14,19 +14,20 @@ with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]==4.3
@component @component
class SimilarityRanker: class TransformersSimilarityRanker:
""" """
Ranks documents based on query similarity. Ranks documents based on query similarity.
It uses a pre-trained cross-encoder model (from Hugging Face Hub) to embed the query and documents.
Usage example: Usage example:
``` ```
from haystack.preview import Document from haystack.preview import Document
from haystack.preview.components.rankers import SimilarityRanker from haystack.preview.components.rankers import TransformersSimilarityRanker
sampler = SimilarityRanker() ranker = TransformersSimilarityRanker()
docs = [Document(text="Paris"), Document(text="Berlin")] docs = [Document(text="Paris"), Document(text="Berlin")]
query = "City in Germany" query = "City in Germany"
output = sampler.run(query=query, documents=docs) output = ranker.run(query=query, documents=docs)
docs = output["documents"] docs = output["documents"]
assert len(docs) == 2 assert len(docs) == 2
assert docs[0].text == "Berlin" assert docs[0].text == "Berlin"
@ -41,9 +42,10 @@ class SimilarityRanker:
top_k: int = 10, top_k: int = 10,
): ):
""" """
Creates an instance of SimilarityRanker. Creates an instance of TransformersSimilarityRanker.
:param model_name_or_path: Path to a pre-trained sentence-transformers model. :param model_name_or_path: The name or path of a pre-trained cross-encoder model
from Hugging Face Hub.
:param device: torch device (for example, cuda:0, cpu, mps) to limit model inference to a specific device. :param device: torch device (for example, cuda:0, cpu, mps) to limit model inference to a specific device.
:param token: The API token used to download private models from Hugging Face. :param token: The API token used to download private models from Hugging Face.
If this parameter is set to `True`, then the token generated when running If this parameter is set to `True`, then the token generated when running

View File

@ -0,0 +1,5 @@
---
preview:
- |
Rename `SimilarityRanker` to `TransformersSimilarityRanker`,
as there will be more similarity rankers in the future.

View File

@ -1,30 +1,32 @@
import pytest import pytest
from haystack.preview import Document, ComponentError from haystack.preview import Document, ComponentError
from haystack.preview.components.rankers.similarity import SimilarityRanker from haystack.preview.components.rankers.transformers_similarity import TransformersSimilarityRanker
class TestSimilarityRanker: class TestSimilarityRanker:
@pytest.mark.unit @pytest.mark.unit
def test_to_dict(self): def test_to_dict(self):
component = SimilarityRanker() component = TransformersSimilarityRanker()
data = component.to_dict() data = component.to_dict()
assert data == { assert data == {
"type": "SimilarityRanker", "type": "TransformersSimilarityRanker",
"init_parameters": { "init_parameters": {
"device": "cpu", "device": "cpu",
"top_k": 10, "top_k": 10,
"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
"token": None, "token": None,
"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
}, },
} }
@pytest.mark.unit @pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self): def test_to_dict_with_custom_init_parameters(self):
component = SimilarityRanker(model_name_or_path="my_model", device="cuda", token="my_token", top_k=5) component = TransformersSimilarityRanker(
model_name_or_path="my_model", device="cuda", token="my_token", top_k=5
)
data = component.to_dict() data = component.to_dict()
assert data == { assert data == {
"type": "SimilarityRanker", "type": "TransformersSimilarityRanker",
"init_parameters": { "init_parameters": {
"device": "cuda", "device": "cuda",
"model_name_or_path": "my_model", "model_name_or_path": "my_model",
@ -46,7 +48,7 @@ class TestSimilarityRanker:
""" """
Test if the component ranks documents correctly. Test if the component ranks documents correctly.
""" """
ranker = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2") ranker = TransformersSimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2")
ranker.warm_up() ranker.warm_up()
docs_before = [Document(text=text) for text in docs_before_texts] docs_before = [Document(text=text) for text in docs_before_texts]
output = ranker.run(query=query, documents=docs_before) output = ranker.run(query=query, documents=docs_before)
@ -61,7 +63,7 @@ class TestSimilarityRanker:
# Returns an empty list if no documents are provided # Returns an empty list if no documents are provided
@pytest.mark.integration @pytest.mark.integration
def test_returns_empty_list_if_no_documents_are_provided(self): def test_returns_empty_list_if_no_documents_are_provided(self):
sampler = SimilarityRanker() sampler = TransformersSimilarityRanker()
sampler.warm_up() sampler.warm_up()
output = sampler.run(query="City in Germany", documents=[]) output = sampler.run(query="City in Germany", documents=[])
assert output["documents"] == [] assert output["documents"] == []
@ -69,7 +71,7 @@ class TestSimilarityRanker:
# Raises ComponentError if model is not warmed up # Raises ComponentError if model is not warmed up
@pytest.mark.integration @pytest.mark.integration
def test_raises_component_error_if_model_not_warmed_up(self): def test_raises_component_error_if_model_not_warmed_up(self):
sampler = SimilarityRanker() sampler = TransformersSimilarityRanker()
with pytest.raises(ComponentError): with pytest.raises(ComponentError):
sampler.run(query="query", documents=[Document(text="document")]) sampler.run(query="query", documents=[Document(text="document")])
@ -87,7 +89,7 @@ class TestSimilarityRanker:
""" """
Test if the component ranks documents correctly with a custom top_k. Test if the component ranks documents correctly with a custom top_k.
""" """
ranker = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2) ranker = TransformersSimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2)
ranker.warm_up() ranker.warm_up()
docs_before = [Document(text=text) for text in docs_before_texts] docs_before = [Document(text=text) for text in docs_before_texts]
output = ranker.run(query=query, documents=docs_before) output = ranker.run(query=query, documents=docs_before)