Add top_k to SimilarityRanker (#6036)

This commit is contained in:
Vladimir Blagojevic 2023-10-12 13:52:01 +02:00 committed by GitHub
parent 4b8b6e9191
commit d51be9edac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 8 deletions

View File

@ -1,6 +1,6 @@
import logging import logging
from pathlib import Path from pathlib import Path
from typing import List, Union, Dict, Any from typing import List, Union, Dict, Any, Optional
from haystack.preview import ComponentError, Document, component, default_from_dict, default_to_dict from haystack.preview import ComponentError, Document, component, default_from_dict, default_to_dict
from haystack.preview.lazy_imports import LazyImport from haystack.preview.lazy_imports import LazyImport
@ -34,17 +34,24 @@ class SimilarityRanker:
""" """
def __init__( def __init__(
self, model_name_or_path: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2", device: str = "cpu" self,
model_name_or_path: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2",
top_k: int = 10,
device: str = "cpu",
): ):
""" """
Creates an instance of SimilarityRanker. Creates an instance of SimilarityRanker.
:param model_name_or_path: Path to a pre-trained sentence-transformers model. :param model_name_or_path: Path to a pre-trained sentence-transformers model.
:param top_k: The maximum number of documents to return per query.
:param device: torch device (for example, cuda:0, cpu, mps) to limit model inference to a specific device. :param device: torch device (for example, cuda:0, cpu, mps) to limit model inference to a specific device.
""" """
torch_and_transformers_import.check() torch_and_transformers_import.check()
self.model_name_or_path = model_name_or_path self.model_name_or_path = model_name_or_path
if top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
self.top_k = top_k
self.device = device self.device = device
self.model = None self.model = None
self.tokenizer = None self.tokenizer = None
@ -63,7 +70,7 @@ class SimilarityRanker:
""" """
Serialize this component to a dictionary. Serialize this component to a dictionary.
""" """
return default_to_dict(self, device=self.device, model_name_or_path=self.model_name_or_path) return default_to_dict(self, top_k=self.top_k, device=self.device, model_name_or_path=self.model_name_or_path)
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SimilarityRanker": def from_dict(cls, data: Dict[str, Any]) -> "SimilarityRanker":
@ -73,17 +80,24 @@ class SimilarityRanker:
return default_from_dict(cls, data) return default_from_dict(cls, data)
@component.output_types(documents=List[Document]) @component.output_types(documents=List[Document])
def run(self, query: str, documents: List[Document]): def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
""" """
Returns a list of documents ranked by their similarity to the given query Returns a list of documents ranked by their similarity to the given query
:param query: Query string. :param query: Query string.
:param documents: List of Documents. :param documents: List of Documents.
:param top_k: The maximum number of documents to return.
:return: List of Documents sorted by (desc.) similarity with the query. :return: List of Documents sorted by (desc.) similarity with the query.
""" """
if not documents: if not documents:
return {"documents": []} return {"documents": []}
if top_k is None:
top_k = self.top_k
elif top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
# If a model path is provided but the model isn't loaded # If a model path is provided but the model isn't loaded
if self.model_name_or_path and not self.model: if self.model_name_or_path and not self.model:
raise ComponentError( raise ComponentError(
@ -105,4 +119,4 @@ class SimilarityRanker:
i = sorted_index_tensor.item() i = sorted_index_tensor.item()
documents[i].score = similarity_scores[i].item() documents[i].score = similarity_scores[i].item()
ranked_docs.append(documents[i]) ranked_docs.append(documents[i])
return {"documents": ranked_docs} return {"documents": ranked_docs[:top_k]}

View File

@ -11,7 +11,11 @@ class TestSimilarityRanker:
data = component.to_dict() data = component.to_dict()
assert data == { assert data == {
"type": "SimilarityRanker", "type": "SimilarityRanker",
"init_parameters": {"device": "cpu", "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"}, "init_parameters": {
"device": "cpu",
"top_k": 10,
"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
},
} }
@pytest.mark.unit @pytest.mark.unit
@ -20,14 +24,22 @@ class TestSimilarityRanker:
data = component.to_dict() data = component.to_dict()
assert data == { assert data == {
"type": "SimilarityRanker", "type": "SimilarityRanker",
"init_parameters": {"device": "cpu", "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"}, "init_parameters": {
"device": "cpu",
"top_k": 10,
"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
},
} }
@pytest.mark.integration @pytest.mark.integration
def test_from_dict(self): def test_from_dict(self):
data = { data = {
"type": "SimilarityRanker", "type": "SimilarityRanker",
"init_parameters": {"device": "cpu", "model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2"}, "init_parameters": {
"device": "cpu",
"top_k": 10,
"model_name_or_path": "cross-encoder/ms-marco-MiniLM-L-6-v2",
},
} }
component = SimilarityRanker.from_dict(data) component = SimilarityRanker.from_dict(data)
assert component.model_name_or_path == "cross-encoder/ms-marco-MiniLM-L-6-v2" assert component.model_name_or_path == "cross-encoder/ms-marco-MiniLM-L-6-v2"
@ -72,3 +84,28 @@ class TestSimilarityRanker:
with pytest.raises(ComponentError): with pytest.raises(ComponentError):
sampler.run(query="query", documents=[Document(text="document")]) sampler.run(query="query", documents=[Document(text="document")])
@pytest.mark.integration
@pytest.mark.parametrize(
    "query,docs_before_texts,expected_first_text",
    [
        ("City in Bosnia and Herzegovina", ["Berlin", "Belgrade", "Sarajevo"], "Sarajevo"),
        ("Machine learning", ["Python", "Bakery in Paris", "Tesla Giga Berlin"], "Python"),
        ("Cubist movement", ["Nirvana", "Pablo Picasso", "Coffee"], "Pablo Picasso"),
    ],
)
def test_run_top_k(self, query, docs_before_texts, expected_first_text):
    """
    Run a ranker constructed with top_k=2 over three candidate documents.

    Checks that the output holds exactly two documents, that the first one
    carries the expected text, and that the returned scores are in
    descending order.
    """
    ranker = SimilarityRanker(model_name_or_path="cross-encoder/ms-marco-MiniLM-L-6-v2", top_k=2)
    ranker.warm_up()

    candidates = [Document(text=text) for text in docs_before_texts]
    ranked = ranker.run(query=query, documents=candidates)["documents"]

    # top_k was set at construction time; no per-call override is passed here.
    assert len(ranked) == 2
    assert ranked[0].text == expected_first_text

    # The returned order must already be descending by score.
    scores = [doc.score for doc in ranked]
    assert scores == sorted(scores, reverse=True)