mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-30 03:16:46 +00:00
Add top_k to SimilarityRanker (#6036)
This commit is contained in:
parent
4b8b6e9191
commit
d51be9edac
@ -1,6 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Union, Dict, Any
|
from typing import List, Union, Dict, Any, Optional
|
||||||
|
|
||||||
from haystack.preview import ComponentError, Document, component, default_from_dict, default_to_dict
|
from haystack.preview import ComponentError, Document, component, default_from_dict, default_to_dict
|
||||||
from haystack.preview.lazy_imports import LazyImport
|
from haystack.preview.lazy_imports import LazyImport
|
||||||
@ -34,17 +34,24 @@ class SimilarityRanker:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model_name_or_path: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2", device: str = "cpu"
|
self,
|
||||||
|
model_name_or_path: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||||
|
top_k: int = 10,
|
||||||
|
device: str = "cpu",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Creates an instance of SimilarityRanker.
|
Creates an instance of SimilarityRanker.
|
||||||
|
|
||||||
:param model_name_or_path: Path to a pre-trained sentence-transformers model.
|
:param model_name_or_path: Path to a pre-trained sentence-transformers model.
|
||||||
|
:param top_k: The maximum number of documents to return per query.
|
||||||
:param device: torch device (for example, cuda:0, cpu, mps) to limit model inference to a specific device.
|
:param device: torch device (for example, cuda:0, cpu, mps) to limit model inference to a specific device.
|
||||||
"""
|
"""
|
||||||
torch_and_transformers_import.check()
|
torch_and_transformers_import.check()
|
||||||
|
|
||||||
self.model_name_or_path = model_name_or_path
|
self.model_name_or_path = model_name_or_path
|
||||||
|
if top_k <= 0:
|
||||||
|
raise ValueError(f"top_k must be > 0, but got {top_k}")
|
||||||
|
self.top_k = top_k
|
||||||
self.device = device
|
self.device = device
|
||||||
self.model = None
|
self.model = None
|
||||||
self.tokenizer = None
|
self.tokenizer = None
|
||||||
@ -63,7 +70,7 @@ class SimilarityRanker:
|
|||||||
"""
|
"""
|
||||||
Serialize this component to a dictionary.
|
Serialize this component to a dictionary.
|
||||||
"""
|
"""
|
||||||
return default_to_dict(self, device=self.device, model_name_or_path=self.model_name_or_path)
|
return default_to_dict(self, top_k=self.top_k, device=self.device, model_name_or_path=self.model_name_or_path)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> "SimilarityRanker":
|
def from_dict(cls, data: Dict[str, Any]) -> "SimilarityRanker":
|
||||||
@ -73,17 +80,24 @@ class SimilarityRanker:
|
|||||||
return default_from_dict(cls, data)
|
return default_from_dict(cls, data)
|
||||||
|
|
||||||
@component.output_types(documents=List[Document])
|
@component.output_types(documents=List[Document])
|
||||||
def run(self, query: str, documents: List[Document]):
|
def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
|
||||||
"""
|
"""
|
||||||
Returns a list of documents ranked by their similarity to the given query
|
Returns a list of documents ranked by their similarity to the given query
|
||||||
|
|
||||||
:param query: Query string.
|
:param query: Query string.
|
||||||
:param documents: List of Documents.
|
:param documents: List of Documents.
|
||||||
|
:param top_k: The maximum number of documents to return.
|
||||||
:return: List of Documents sorted by (desc.) similarity with the query.
|
:return: List of Documents sorted by (desc.) similarity with the query.
|
||||||
"""
|
"""
|
||||||
if not documents:
|
if not documents:
|
||||||
return {"documents": []}
|
return {"documents": []}
|
||||||
|
|
||||||
|
if top_k is None:
|
||||||
|
top_k = self.top_k
|
||||||
|
|
||||||
|
elif top_k <= 0:
|
||||||
|
raise ValueError(f"top_k must be > 0, but got {top_k}")
|
||||||
|
|
||||||
# If a model path is provided but the model isn't loaded
|
# If a model path is provided but the model isn't loaded
|
||||||
if self.model_name_or_path and not self.model:
|
if self.model_name_or_path and not self.model:
|
||||||
raise ComponentError(
|
raise ComponentError(
|
||||||
@ -105,4 +119,4 @@ class SimilarityRanker:
|
|||||||
i = sorted_index_tensor.item()
|
i = sorted_index_tensor.item()
|
||||||
documents[i].score = similarity_scores[i].item()
|
documents[i].score = similarity_scores[i].item()
|
||||||
ranked_docs.append(documents[i])
|
ranked_docs.append(documents[i])
|
||||||
return {"documents": ranked_docs}
|
return {"documents": ranked_docs[:top_k]}
|
||||||
|
@ -11,7 +11,11 @@ class TestSimilarityRanker:
|
|||||||
data = component.to_dict()
|
data = component.to_dict()
|
||||||
assert data == {
|
assert data == {
|
||||||
"type": "SimilarityRanker",
|
"type": "SimilarityRanker",
|
||||||
"init_parameters": {"device": "cpu", "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"},
|
"init_parameters": {
|
||||||
|
"device": "cpu",
|
||||||
|
"top_k": 10,
|
||||||
|
"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@pytest.mark.unit
|
@pytest.mark.unit
|
||||||
@ -20,14 +24,22 @@ class TestSimilarityRanker:
|
|||||||
data = component.to_dict()
|
data = component.to_dict()
|
||||||
assert data == {
|
assert data == {
|
||||||
"type": "SimilarityRanker",
|
"type": "SimilarityRanker",
|
||||||
"init_parameters": {"device": "cpu", "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"},
|
"init_parameters": {
|
||||||
|
"device": "cpu",
|
||||||
|
"top_k": 10,
|
||||||
|
"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@pytest.mark.integration
|
@pytest.mark.integration
|
||||||
def test_from_dict(self):
|
def test_from_dict(self):
|
||||||
data = {
|
data = {
|
||||||
"type": "SimilarityRanker",
|
"type": "SimilarityRanker",
|
||||||
"init_parameters": {"device": "cpu", "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"},
|
"init_parameters": {
|
||||||
|
"device": "cpu",
|
||||||
|
"top_k": 10,
|
||||||
|
"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
component = SimilarityRanker.from_dict(data)
|
component = SimilarityRanker.from_dict(data)
|
||||||
assert component.model_name_or_path == "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
assert component.model_name_or_path == "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
||||||
@ -72,3 +84,28 @@ class TestSimilarityRanker:
|
|||||||
|
|
||||||
with pytest.raises(ComponentError):
|
with pytest.raises(ComponentError):
|
||||||
sampler.run(query="query", documents=[Document(text="document")])
|
sampler.run(query="query", documents=[Document(text="document")])
|
||||||
|
|
||||||
|
@pytest.mark.integration
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"query,docs_before_texts,expected_first_text",
|
||||||
|
[
|
||||||
|
("City in Bosnia and Herzegovina", ["Berlin", "Belgrade", "Sarajevo"], "Sarajevo"),
|
||||||
|
("Machine learning", ["Python", "Bakery in Paris", "Tesla Giga Berlin"], "Python"),
|
||||||
|
("Cubist movement", ["Nirvana", "Pablo Picasso", "Coffee"], "Pablo Picasso"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_run_top_k(self, query, docs_before_texts, expected_first_text):
|
||||||
|
"""
|
||||||
|
Test if the component ranks documents correctly with a custom top_k.
|
||||||
|
"""
|
||||||
|
ranker = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2)
|
||||||
|
ranker.warm_up()
|
||||||
|
docs_before = [Document(text=text) for text in docs_before_texts]
|
||||||
|
output = ranker.run(query=query, documents=docs_before)
|
||||||
|
docs_after = output["documents"]
|
||||||
|
|
||||||
|
assert len(docs_after) == 2
|
||||||
|
assert docs_after[0].text == expected_first_text
|
||||||
|
|
||||||
|
sorted_scores = sorted([doc.score for doc in docs_after], reverse=True)
|
||||||
|
assert [doc.score for doc in docs_after] == sorted_scores
|
||||||
|
Loading…
x
Reference in New Issue
Block a user