refactor!: rename SimilarityRanker to TransformersSimilarityRanker (#6100)
* rename
* release note
* Update haystack/preview/components/rankers/transformers_similarity.py
* fix test

Co-authored-by: Domenico <domenico.cinque98@gmail.com>
commit 1f4ed3cc03
parent 1cf70d3dce
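For anyone updating code against this change, a minimal before/after sketch based only on the import paths and usage example touched in this diff (the `warm_up()` call mirrors the tests further down):

```
# Old import, removed by this commit:
#   from haystack.preview.components.rankers import SimilarityRanker

# New import and usage after the rename:
from haystack.preview import Document
from haystack.preview.components.rankers import TransformersSimilarityRanker

ranker = TransformersSimilarityRanker()  # defaults to cross-encoder/ms-marco-MiniLM-L-6-v2
ranker.warm_up()  # load the model; run() raises ComponentError if skipped (see tests below)

docs = [Document(text="Paris"), Document(text="Berlin")]
output = ranker.run(query="City in Germany", documents=docs)
print([doc.text for doc in output["documents"]])  # highest-scoring document first
```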
@@ -1,3 +1,3 @@
-from haystack.preview.components.rankers.similarity import SimilarityRanker
+from haystack.preview.components.rankers.transformers_similarity import TransformersSimilarityRanker
 
-__all__ = ["SimilarityRanker"]
+__all__ = ["TransformersSimilarityRanker"]
@@ -14,19 +14,20 @@ with LazyImport(message="Run 'pip install transformers[torch,sentencepiece]==4.3
 
 
 @component
-class SimilarityRanker:
+class TransformersSimilarityRanker:
     """
     Ranks documents based on query similarity.
+    It uses a pre-trained cross-encoder model (from Hugging Face Hub) to embed the query and documents.
 
     Usage example:
     ```
     from haystack.preview import Document
-    from haystack.preview.components.rankers import SimilarityRanker
+    from haystack.preview.components.rankers import TransformersSimilarityRanker
 
-    sampler = SimilarityRanker()
+    ranker = TransformersSimilarityRanker()
     docs = [Document(text="Paris"), Document(text="Berlin")]
     query = "City in Germany"
-    output = sampler.run(query=query, documents=docs)
+    output = ranker.run(query=query, documents=docs)
     docs = output["documents"]
     assert len(docs) == 2
     assert docs[0].text == "Berlin"
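As background for the docstring above, here is a rough sketch of the cross-encoder ranking idea using `transformers` directly; it illustrates the general technique, not this component's internals, and the variable names are made up for the example:

```
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

query = "City in Germany"
texts = ["Paris", "Berlin"]

# A cross-encoder scores each (query, document) pair jointly; ranking is a sort by that score.
features = tokenizer([query] * len(texts), texts, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**features).logits.squeeze(-1)
ranked = [text for _, text in sorted(zip(scores.tolist(), texts), reverse=True)]
print(ranked)  # expected: ['Berlin', 'Paris']
```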
@@ -41,9 +42,10 @@ class SimilarityRanker:
         top_k: int = 10,
     ):
         """
-        Creates an instance of SimilarityRanker.
+        Creates an instance of TransformersSimilarityRanker.
 
-        :param model_name_or_path: Path to a pre-trained sentence-transformers model.
+        :param model_name_or_path: The name or path of a pre-trained cross-encoder model
+            from Hugging Face Hub.
         :param device: torch device (for example, cuda:0, cpu, mps) to limit model inference to a specific device.
         :param token: The API token used to download private models from Hugging Face.
             If this parameter is set to `True`, then the token generated when running
@ -0,0 +1,5 @@
|
|||||||
|
---
|
||||||
|
preview:
|
||||||
|
- |
|
||||||
|
Rename `SimilarityRanker` to `TransformersSimilarityRanker`,
|
||||||
|
as there will be more similarity rankers in the future.
|
@ -1,30 +1,32 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from haystack.preview import Document, ComponentError
|
from haystack.preview import Document, ComponentError
|
||||||
from haystack.preview.components.rankers.similarity import SimilarityRanker
|
from haystack.preview.components.rankers.transformers_similarity import TransformersSimilarityRanker
|
||||||
|
|
||||||
|
|
||||||
class TestSimilarityRanker:
|
class TestSimilarityRanker:
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_to_dict(self):
|
def test_to_dict(self):
|
||||||
component = SimilarityRanker()
|
component = TransformersSimilarityRanker()
|
||||||
data = component.to_dict()
|
data = component.to_dict()
|
||||||
assert data == {
|
assert data == {
|
||||||
"type": "SimilarityRanker",
|
"type": "TransformersSimilarityRanker",
|
||||||
"init_parameters": {
|
"init_parameters": {
|
||||||
"device": "cpu",
|
"device": "cpu",
|
||||||
"top_k": 10,
|
"top_k": 10,
|
||||||
"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
|
||||||
"token": None,
|
"token": None,
|
||||||
|
"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
def test_to_dict_with_custom_init_parameters(self):
|
def test_to_dict_with_custom_init_parameters(self):
|
||||||
component = SimilarityRanker(model_name_or_path="my_model", device="cuda", token="my_token", top_k=5)
|
component = TransformersSimilarityRanker(
|
||||||
|
model_name_or_path="my_model", device="cuda", token="my_token", top_k=5
|
||||||
|
)
|
||||||
data = component.to_dict()
|
data = component.to_dict()
|
||||||
assert data == {
|
assert data == {
|
||||||
"type": "SimilarityRanker",
|
"type": "TransformersSimilarityRanker",
|
||||||
"init_parameters": {
|
"init_parameters": {
|
||||||
"device": "cuda",
|
"device": "cuda",
|
||||||
"model_name_or_path": "my_model",
|
"model_name_or_path": "my_model",
|
||||||
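One consequence visible in these assertions: the serialized `type` string changes with the class name, so any saved `to_dict()` output or pipeline definition that still reads `"SimilarityRanker"` needs updating. A small sketch of the new default serialization, copied from the test expectations above:

```
from haystack.preview.components.rankers import TransformersSimilarityRanker

expected = {
    "type": "TransformersSimilarityRanker",
    "init_parameters": {
        "device": "cpu",
        "top_k": 10,
        "token": None,
        "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    },
}
assert TransformersSimilarityRanker().to_dict() == expected
```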
@@ -46,7 +48,7 @@ class TestSimilarityRanker:
         """
         Test if the component ranks documents correctly.
         """
-        ranker = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2")
+        ranker = TransformersSimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2")
         ranker.warm_up()
         docs_before = [Document(text=text) for text in docs_before_texts]
         output = ranker.run(query=query, documents=docs_before)
@@ -61,7 +63,7 @@ class TestSimilarityRanker:
     # Returns an empty list if no documents are provided
     @pytest.mark.integration
     def test_returns_empty_list_if_no_documents_are_provided(self):
-        sampler = SimilarityRanker()
+        sampler = TransformersSimilarityRanker()
         sampler.warm_up()
         output = sampler.run(query="City in Germany", documents=[])
         assert output["documents"] == []
@@ -69,7 +71,7 @@ class TestSimilarityRanker:
     # Raises ComponentError if model is not warmed up
     @pytest.mark.integration
     def test_raises_component_error_if_model_not_warmed_up(self):
-        sampler = SimilarityRanker()
+        sampler = TransformersSimilarityRanker()
 
         with pytest.raises(ComponentError):
             sampler.run(query="query", documents=[Document(text="document")])
@@ -87,7 +89,7 @@ class TestSimilarityRanker:
         """
         Test if the component ranks documents correctly with a custom top_k.
         """
-        ranker = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2)
+        ranker = TransformersSimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2)
         ranker.warm_up()
         docs_before = [Document(text=text) for text in docs_before_texts]
         output = ranker.run(query=query, documents=docs_before)