mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-24 13:38:53 +00:00
feat: Add SentenceTransformersDiversityRanker (#7095)
* Add Diversity Ranker * Update tests * Add separate suffix, prefix params for query and documents; allow empty query * Update docstrings * Make changes based on review * Add additional tests * Add test for warm up * Update release notes --------- Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>
This commit is contained in:
parent
6239b60814
commit
38b3472bb2
@ -1,5 +1,11 @@
|
||||
from haystack.components.rankers.lost_in_the_middle import LostInTheMiddleRanker
|
||||
from haystack.components.rankers.meta_field import MetaFieldRanker
|
||||
from haystack.components.rankers.sentence_transformers_diversity import SentenceTransformersDiversityRanker
|
||||
from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker
|
||||
|
||||
__all__ = ["LostInTheMiddleRanker", "MetaFieldRanker", "TransformersSimilarityRanker"]
|
||||
__all__ = [
|
||||
"LostInTheMiddleRanker",
|
||||
"MetaFieldRanker",
|
||||
"SentenceTransformersDiversityRanker",
|
||||
"TransformersSimilarityRanker",
|
||||
]
|
||||
|
||||
246
haystack/components/rankers/sentence_transformers_diversity.py
Normal file
246
haystack/components/rankers/sentence_transformers_diversity.py
Normal file
@ -0,0 +1,246 @@
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from haystack import ComponentError, Document, component, default_from_dict, default_to_dict, logging
|
||||
from haystack.lazy_imports import LazyImport
|
||||
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
with LazyImport(message="Run 'pip install \"sentence-transformers>=2.2.0\"'") as torch_and_sentence_transformers_import:
|
||||
import torch
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
|
||||
@component
|
||||
class SentenceTransformersDiversityRanker:
|
||||
"""
|
||||
Implements a document ranking algorithm that orders documents in such a way as to maximize the overall diversity
|
||||
of the documents.
|
||||
|
||||
This component provides functionality to rank a list of documents based on their similarity with respect to the
|
||||
query to maximize the overall diversity. It uses a pre-trained Sentence Transformers model to embed the query and
|
||||
the Documents.
|
||||
|
||||
Usage example:
|
||||
```python
|
||||
from haystack import Document
|
||||
from haystack.components.rankers import SentenceTransformersDiversityRanker
|
||||
|
||||
ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine")
|
||||
ranker.warm_up()
|
||||
|
||||
docs = [Document(content="Paris"), Document(content="Berlin")]
|
||||
query = "What is the capital of germany?"
|
||||
output = ranker.run(query=query, documents=docs)
|
||||
docs = output["documents"]
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "sentence-transformers/all-MiniLM-L6-v2",
|
||||
top_k: int = 10,
|
||||
device: Optional[ComponentDevice] = None,
|
||||
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
|
||||
similarity: Literal["dot_product", "cosine"] = "cosine",
|
||||
query_prefix: str = "",
|
||||
query_suffix: str = "",
|
||||
document_prefix: str = "",
|
||||
document_suffix: str = "",
|
||||
meta_fields_to_embed: Optional[List[str]] = None,
|
||||
embedding_separator: str = "\n",
|
||||
):
|
||||
"""
|
||||
Initialize a SentenceTransformersDiversityRanker.
|
||||
|
||||
:param model: Local path or name of the model in Hugging Face's model hub,
|
||||
such as `'sentence-transformers/all-MiniLM-L6-v2'`.
|
||||
:param top_k: The maximum number of Documents to return per query.
|
||||
:param device: The device on which the model is loaded. If `None`, the default device is automatically
|
||||
selected.
|
||||
:param token: The API token used to download private models from Hugging Face.
|
||||
:param similarity: Similarity metric for comparing embeddings. Can be set to "dot_product" (default) or
|
||||
"cosine".
|
||||
:param query_prefix: A string to add to the beginning of the query text before ranking.
|
||||
Can be used to prepend the text with an instruction, as required by some embedding models,
|
||||
such as E5 and BGE.
|
||||
:param query_suffix: A string to add to the end of the query text before ranking.
|
||||
:param document_prefix: A string to add to the beginning of each Document text before ranking.
|
||||
Can be used to prepend the text with an instruction, as required by some embedding models,
|
||||
such as E5 and BGE.
|
||||
:param document_suffix: A string to add to the end of each Document text before ranking.
|
||||
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
|
||||
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
|
||||
"""
|
||||
torch_and_sentence_transformers_import.check()
|
||||
|
||||
self.model_name_or_path = model
|
||||
if top_k is None or top_k <= 0:
|
||||
raise ValueError(f"top_k must be > 0, but got {top_k}")
|
||||
self.top_k = top_k
|
||||
self.device = ComponentDevice.resolve_device(device)
|
||||
self.token = token
|
||||
self.model = None
|
||||
if similarity not in ["dot_product", "cosine"]:
|
||||
raise ValueError(f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}.")
|
||||
self.similarity = similarity
|
||||
self.query_prefix = query_prefix
|
||||
self.document_prefix = document_prefix
|
||||
self.query_suffix = query_suffix
|
||||
self.document_suffix = document_suffix
|
||||
self.meta_fields_to_embed = meta_fields_to_embed or []
|
||||
self.embedding_separator = embedding_separator
|
||||
|
||||
def warm_up(self):
|
||||
"""
|
||||
Initializes the component.
|
||||
"""
|
||||
if self.model is None:
|
||||
self.model = SentenceTransformer(
|
||||
model_name_or_path=self.model_name_or_path,
|
||||
device=self.device.to_torch_str(),
|
||||
use_auth_token=self.token.resolve_value() if self.token else None,
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes the component to a dictionary.
|
||||
|
||||
:returns:
|
||||
Dictionary with serialized data.
|
||||
"""
|
||||
return default_to_dict(
|
||||
self,
|
||||
model=self.model_name_or_path,
|
||||
device=self.device.to_dict(),
|
||||
token=self.token.to_dict() if self.token else None,
|
||||
top_k=self.top_k,
|
||||
similarity=self.similarity,
|
||||
query_prefix=self.query_prefix,
|
||||
document_prefix=self.document_prefix,
|
||||
query_suffix=self.query_suffix,
|
||||
document_suffix=self.document_suffix,
|
||||
meta_fields_to_embed=self.meta_fields_to_embed,
|
||||
embedding_separator=self.embedding_separator,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDiversityRanker":
|
||||
"""
|
||||
Deserializes the component from a dictionary.
|
||||
|
||||
:param data:
|
||||
The dictionary to deserialize from.
|
||||
:returns:
|
||||
The deserialized component.
|
||||
"""
|
||||
serialized_device = data["init_parameters"]["device"]
|
||||
data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)
|
||||
|
||||
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
|
||||
"""
|
||||
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
|
||||
"""
|
||||
texts_to_embed = []
|
||||
for doc in documents:
|
||||
meta_values_to_embed = [
|
||||
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key]
|
||||
]
|
||||
text_to_embed = (
|
||||
self.document_prefix
|
||||
+ self.embedding_separator.join(meta_values_to_embed + [doc.content or ""])
|
||||
+ self.document_suffix
|
||||
)
|
||||
texts_to_embed.append(text_to_embed)
|
||||
|
||||
return texts_to_embed
|
||||
|
||||
def _greedy_diversity_order(self, query: str, documents: List[Document]) -> List[Document]:
|
||||
"""
|
||||
Orders the given list of documents to maximize diversity.
|
||||
|
||||
The algorithm first calculates embeddings for each document and the query. It starts by selecting the document
|
||||
that is semantically closest to the query. Then, for each remaining document, it selects the one that, on
|
||||
average, is least similar to the already selected documents. This process continues until all documents are
|
||||
selected, resulting in a list where each subsequent document contributes the most to the overall diversity of
|
||||
the selected set.
|
||||
|
||||
:param query: The search query.
|
||||
:param documents: The list of Document objects to be ranked.
|
||||
|
||||
:return: A list of documents ordered to maximize diversity.
|
||||
"""
|
||||
texts_to_embed = self._prepare_texts_to_embed(documents)
|
||||
|
||||
# Calculate embeddings
|
||||
doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True) # type: ignore[attr-defined]
|
||||
query_embedding = self.model.encode([self.query_prefix + query + self.query_suffix], convert_to_tensor=True) # type: ignore[attr-defined]
|
||||
|
||||
# Normalize embeddings to unit length for computing cosine similarity
|
||||
if self.similarity == "cosine":
|
||||
doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
|
||||
query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)
|
||||
|
||||
n = len(documents)
|
||||
selected: List[int] = []
|
||||
|
||||
# Compute the similarity vector between the query and documents
|
||||
query_doc_sim = query_embedding @ doc_embeddings.T
|
||||
|
||||
# Start with the document with the highest similarity to the query
|
||||
selected.append(int(torch.argmax(query_doc_sim).item()))
|
||||
|
||||
selected_sum = doc_embeddings[selected[0]] / n
|
||||
|
||||
while len(selected) < n:
|
||||
# Compute mean of dot products of all selected documents and all other documents
|
||||
similarities = selected_sum @ doc_embeddings.T
|
||||
# Mask documents that are already selected
|
||||
similarities[selected] = torch.inf
|
||||
# Select the document with the lowest total similarity score
|
||||
index_unselected = int(torch.argmin(similarities).item())
|
||||
selected.append(index_unselected)
|
||||
# It's enough just to add to the selected vectors because dot product is distributive
|
||||
# It's divided by n for numerical stability
|
||||
selected_sum += doc_embeddings[index_unselected] / n
|
||||
|
||||
ranked_docs: List[Document] = [documents[i] for i in selected]
|
||||
|
||||
return ranked_docs
|
||||
|
||||
@component.output_types(documents=List[Document])
|
||||
def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
|
||||
"""
|
||||
Rank the documents based on their diversity.
|
||||
|
||||
:param query: The search query.
|
||||
:param documents: List of Document objects to be ranker.
|
||||
:param top_k: Optional. An integer to override the top_k set during initialization.
|
||||
|
||||
:returns: A dictionary with the following key:
|
||||
- `documents`: List of Document objects that have been selected based on the diversity ranking.
|
||||
|
||||
:raises ValueError: If the top_k value is less than or equal to 0.
|
||||
"""
|
||||
if not documents:
|
||||
return {"documents": []}
|
||||
|
||||
if top_k is None:
|
||||
top_k = self.top_k
|
||||
elif top_k <= 0:
|
||||
raise ValueError(f"top_k must be > 0, but got {top_k}")
|
||||
|
||||
if self.model is None:
|
||||
error_msg = (
|
||||
"The component SentenceTransformersDiversityRanker wasn't warmed up. "
|
||||
"Run 'warm_up()' before calling 'run()'."
|
||||
)
|
||||
raise ComponentError(error_msg)
|
||||
|
||||
diversity_sorted = self._greedy_diversity_order(query=query, documents=documents)
|
||||
|
||||
return {"documents": diversity_sorted[:top_k]}
|
||||
@ -0,0 +1,6 @@
|
||||
---
|
||||
features:
|
||||
- |
|
||||
Add `SentenceTransformersDiversityRanker`.
|
||||
The Diversity Ranker orders documents in such a way as to maximize the overall diversity of the given documents.
|
||||
The ranker leverages sentence-transformer models to calculate semantic embeddings for each document and the query.
|
||||
554
test/components/rankers/test_sentence_transformers_diversity.py
Normal file
554
test/components/rankers/test_sentence_transformers_diversity.py
Normal file
@ -0,0 +1,554 @@
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from haystack import ComponentError, Document
|
||||
from haystack.components.rankers import SentenceTransformersDiversityRanker
|
||||
from haystack.utils import ComponentDevice
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
|
||||
def mock_encode_response(texts, **kwargs):
|
||||
if texts == ["city"]:
|
||||
return torch.tensor([[1.0, 1.0]])
|
||||
elif texts == ["Eiffel Tower", "Berlin", "Bananas"]:
|
||||
return torch.tensor([[1.0, 0.0], [0.8, 0.8], [0.0, 1.0]])
|
||||
else:
|
||||
return torch.tensor([[0.0, 1.0]] * len(texts))
|
||||
|
||||
|
||||
class TestSentenceTransformersDiversityRanker:
|
||||
def test_init(self):
|
||||
component = SentenceTransformersDiversityRanker()
|
||||
assert component.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
assert component.top_k == 10
|
||||
assert component.device == ComponentDevice.resolve_device(None)
|
||||
assert component.similarity == "cosine"
|
||||
assert component.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||
assert component.query_prefix == ""
|
||||
assert component.document_prefix == ""
|
||||
assert component.query_suffix == ""
|
||||
assert component.document_suffix == ""
|
||||
assert component.meta_fields_to_embed == []
|
||||
assert component.embedding_separator == "\n"
|
||||
|
||||
def test_init_with_custom_init_parameters(self):
|
||||
component = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/msmarco-distilbert-base-v4",
|
||||
top_k=5,
|
||||
device=ComponentDevice.from_str("cuda:0"),
|
||||
token=Secret.from_token("fake-api-token"),
|
||||
similarity="dot_product",
|
||||
query_prefix="query:",
|
||||
document_prefix="document:",
|
||||
query_suffix="query suffix",
|
||||
document_suffix="document suffix",
|
||||
meta_fields_to_embed=["meta_field"],
|
||||
embedding_separator="--",
|
||||
)
|
||||
assert component.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4"
|
||||
assert component.top_k == 5
|
||||
assert component.device == ComponentDevice.from_str("cuda:0")
|
||||
assert component.similarity == "dot_product"
|
||||
assert component.token == Secret.from_token("fake-api-token")
|
||||
assert component.query_prefix == "query:"
|
||||
assert component.document_prefix == "document:"
|
||||
assert component.query_suffix == "query suffix"
|
||||
assert component.document_suffix == "document suffix"
|
||||
assert component.meta_fields_to_embed == ["meta_field"]
|
||||
assert component.embedding_separator == "--"
|
||||
|
||||
def test_to_dict(self):
|
||||
component = SentenceTransformersDiversityRanker()
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker",
|
||||
"init_parameters": {
|
||||
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
"top_k": 10,
|
||||
"device": ComponentDevice.resolve_device(None).to_dict(),
|
||||
"similarity": "cosine",
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"query_prefix": "",
|
||||
"document_prefix": "",
|
||||
"query_suffix": "",
|
||||
"document_suffix": "",
|
||||
"meta_fields_to_embed": [],
|
||||
"embedding_separator": "\n",
|
||||
},
|
||||
}
|
||||
|
||||
def test_from_dict(self):
|
||||
data = {
|
||||
"type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker",
|
||||
"init_parameters": {
|
||||
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
"top_k": 10,
|
||||
"device": ComponentDevice.resolve_device(None).to_dict(),
|
||||
"similarity": "cosine",
|
||||
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"query_prefix": "",
|
||||
"document_prefix": "",
|
||||
"query_suffix": "",
|
||||
"document_suffix": "",
|
||||
"meta_fields_to_embed": [],
|
||||
"embedding_separator": "\n",
|
||||
},
|
||||
}
|
||||
ranker = SentenceTransformersDiversityRanker.from_dict(data)
|
||||
|
||||
assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
assert ranker.top_k == 10
|
||||
assert ranker.device == ComponentDevice.resolve_device(None)
|
||||
assert ranker.similarity == "cosine"
|
||||
assert ranker.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
|
||||
assert ranker.query_prefix == ""
|
||||
assert ranker.document_prefix == ""
|
||||
assert ranker.query_suffix == ""
|
||||
assert ranker.document_suffix == ""
|
||||
assert ranker.meta_fields_to_embed == []
|
||||
assert ranker.embedding_separator == "\n"
|
||||
|
||||
def test_to_dict_with_custom_init_parameters(self):
|
||||
component = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/msmarco-distilbert-base-v4",
|
||||
top_k=5,
|
||||
device=ComponentDevice.from_str("cuda:0"),
|
||||
token=Secret.from_env_var("ENV_VAR", strict=False),
|
||||
similarity="dot_product",
|
||||
query_prefix="query:",
|
||||
document_prefix="document:",
|
||||
query_suffix="query suffix",
|
||||
document_suffix="document suffix",
|
||||
meta_fields_to_embed=["meta_field"],
|
||||
embedding_separator="--",
|
||||
)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker",
|
||||
"init_parameters": {
|
||||
"model": "sentence-transformers/msmarco-distilbert-base-v4",
|
||||
"top_k": 5,
|
||||
"device": ComponentDevice.from_str("cuda:0").to_dict(),
|
||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"similarity": "dot_product",
|
||||
"query_prefix": "query:",
|
||||
"document_prefix": "document:",
|
||||
"query_suffix": "query suffix",
|
||||
"document_suffix": "document suffix",
|
||||
"meta_fields_to_embed": ["meta_field"],
|
||||
"embedding_separator": "--",
|
||||
},
|
||||
}
|
||||
|
||||
def test_from_dict_with_custom_init_parameters(self):
|
||||
data = {
|
||||
"type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker",
|
||||
"init_parameters": {
|
||||
"model": "sentence-transformers/msmarco-distilbert-base-v4",
|
||||
"top_k": 5,
|
||||
"device": ComponentDevice.from_str("cuda:0").to_dict(),
|
||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"similarity": "dot_product",
|
||||
"query_prefix": "query:",
|
||||
"document_prefix": "document:",
|
||||
"query_suffix": "query suffix",
|
||||
"document_suffix": "document suffix",
|
||||
"meta_fields_to_embed": ["meta_field"],
|
||||
"embedding_separator": "--",
|
||||
},
|
||||
}
|
||||
ranker = SentenceTransformersDiversityRanker.from_dict(data)
|
||||
|
||||
assert ranker.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4"
|
||||
assert ranker.top_k == 5
|
||||
assert ranker.device == ComponentDevice.from_str("cuda:0")
|
||||
assert ranker.similarity == "dot_product"
|
||||
assert ranker.token == Secret.from_env_var("ENV_VAR", strict=False)
|
||||
assert ranker.query_prefix == "query:"
|
||||
assert ranker.document_prefix == "document:"
|
||||
assert ranker.query_suffix == "query suffix"
|
||||
assert ranker.document_suffix == "document suffix"
|
||||
assert ranker.meta_fields_to_embed == ["meta_field"]
|
||||
assert ranker.embedding_separator == "--"
|
||||
|
||||
def test_run_incorrect_similarity(self):
|
||||
"""
|
||||
Tests that run method raises ValueError if similarity is incorrect
|
||||
"""
|
||||
similarity = "incorrect"
|
||||
with pytest.raises(
|
||||
ValueError, match=f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}."
|
||||
):
|
||||
SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity)
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_run_without_warm_up(self, similarity):
|
||||
"""
|
||||
Tests that run method raises ComponentError if model is not warmed up
|
||||
"""
|
||||
ranker = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2", top_k=1, similarity=similarity
|
||||
)
|
||||
documents = [Document(content="doc1"), Document(content="doc2")]
|
||||
|
||||
error_msg = "The component SentenceTransformersDiversityRanker wasn't warmed up."
|
||||
with pytest.raises(ComponentError, match=error_msg):
|
||||
ranker.run(query="test query", documents=documents)
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_warm_up(self, similarity):
|
||||
"""
|
||||
Test that ranker loads the SentenceTransformer model correctly during warm up.
|
||||
"""
|
||||
mock_model_class = MagicMock()
|
||||
mock_model_instance = MagicMock()
|
||||
mock_model_class.return_value = mock_model_instance
|
||||
|
||||
with patch(
|
||||
"haystack.components.rankers.sentence_transformers_diversity.SentenceTransformer", new=mock_model_class
|
||||
):
|
||||
ranker = SentenceTransformersDiversityRanker(model="mock_model_name", similarity=similarity)
|
||||
|
||||
assert ranker.model is None
|
||||
|
||||
ranker.warm_up()
|
||||
|
||||
mock_model_class.assert_called_once_with(
|
||||
model_name_or_path="mock_model_name",
|
||||
device=ComponentDevice.resolve_device(None).to_torch_str(),
|
||||
use_auth_token=None,
|
||||
)
|
||||
assert ranker.model == mock_model_instance
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_run_empty_query(self, similarity):
|
||||
"""
|
||||
Test that ranker can be run with an empty query.
|
||||
"""
|
||||
ranker = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2", top_k=3, similarity=similarity
|
||||
)
|
||||
ranker.model = MagicMock()
|
||||
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
|
||||
documents = [Document(content="doc1"), Document(content="doc2")]
|
||||
|
||||
result = ranker.run(query="", documents=documents)
|
||||
ranked_docs = result["documents"]
|
||||
|
||||
assert isinstance(ranked_docs, list)
|
||||
assert len(ranked_docs) == 2
|
||||
assert all(isinstance(doc, Document) for doc in ranked_docs)
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_run_top_k(self, similarity):
|
||||
"""
|
||||
Test that run method returns the correct number of documents for different top_k values passed at
|
||||
initialization and runtime.
|
||||
"""
|
||||
ranker = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=3
|
||||
)
|
||||
ranker.model = MagicMock()
|
||||
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
|
||||
query = "test query"
|
||||
documents = [
|
||||
Document(content="doc1"),
|
||||
Document(content="doc2"),
|
||||
Document(content="doc3"),
|
||||
Document(content="doc4"),
|
||||
]
|
||||
|
||||
result = ranker.run(query=query, documents=documents)
|
||||
ranked_docs = result["documents"]
|
||||
|
||||
assert isinstance(ranked_docs, list)
|
||||
assert len(ranked_docs) == 3
|
||||
assert all(isinstance(doc, Document) for doc in ranked_docs)
|
||||
|
||||
# Passing a different top_k at runtime
|
||||
result = ranker.run(query=query, documents=documents, top_k=2)
|
||||
ranked_docs = result["documents"]
|
||||
|
||||
assert isinstance(ranked_docs, list)
|
||||
assert len(ranked_docs) == 2
|
||||
assert all(isinstance(doc, Document) for doc in ranked_docs)
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_run_negative_top_k_at_init(self, similarity):
|
||||
"""
|
||||
Tests that run method raises an error for negative top-k set at init.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="top_k must be > 0, but got"):
|
||||
SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=-5
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_run_top_k_is_none_at_init(self, similarity):
|
||||
"""
|
||||
Tests that run method raises an error for top-k set to None at init.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="top_k must be > 0, but got"):
|
||||
SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=None
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_run_negative_top_k(self, similarity):
|
||||
"""
|
||||
Tests that run method raises an error for negative top-k set at runtime.
|
||||
"""
|
||||
ranker = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=10
|
||||
)
|
||||
ranker.model = MagicMock()
|
||||
query = "test"
|
||||
documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")]
|
||||
|
||||
with pytest.raises(ValueError, match="top_k must be > 0, but got"):
|
||||
ranker.run(query=query, documents=documents, top_k=-5)
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_run_top_k_is_none(self, similarity):
|
||||
"""
|
||||
Tests that run method returns the correct order of documents for top-k set to None.
|
||||
"""
|
||||
# Setting top_k to None is ignored during runtime, it should use top_k set at init.
|
||||
ranker = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=2
|
||||
)
|
||||
ranker.model = MagicMock()
|
||||
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
|
||||
query = "test"
|
||||
documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")]
|
||||
result = ranker.run(query=query, documents=documents, top_k=None)
|
||||
|
||||
assert len(result["documents"]) == 2
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_run_no_documents_provided(self, similarity):
|
||||
"""
|
||||
Test that run method returns an empty list if no documents are supplied.
|
||||
"""
|
||||
ranker = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity
|
||||
)
|
||||
ranker.model = MagicMock()
|
||||
query = "test query"
|
||||
documents = []
|
||||
results = ranker.run(query=query, documents=documents)
|
||||
|
||||
assert len(results["documents"]) == 0
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_run_with_less_documents_than_top_k(self, similarity):
|
||||
"""
|
||||
Tests that run method returns the correct number of documents for top_k values greater than number of documents.
|
||||
"""
|
||||
ranker = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=5
|
||||
)
|
||||
ranker.model = MagicMock()
|
||||
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
|
||||
query = "test"
|
||||
documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")]
|
||||
result = ranker.run(query=query, documents=documents)
|
||||
|
||||
assert len(result["documents"]) == 3
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_run_single_document_corner_case(self, similarity):
|
||||
"""
|
||||
Tests that run method returns the correct number of documents for a single document
|
||||
"""
|
||||
ranker = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity
|
||||
)
|
||||
ranker.model = MagicMock()
|
||||
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
|
||||
query = "test"
|
||||
documents = [Document(content="doc1")]
|
||||
result = ranker.run(query=query, documents=documents)
|
||||
|
||||
assert len(result["documents"]) == 1
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_prepare_texts_to_embed(self, similarity):
|
||||
"""
|
||||
Test creation of texts to embed from documents with meta fields, document prefix and suffix.
|
||||
"""
|
||||
ranker = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
similarity=similarity,
|
||||
document_prefix="test doc: ",
|
||||
document_suffix=" end doc.",
|
||||
meta_fields_to_embed=["meta_field"],
|
||||
embedding_separator="\n",
|
||||
)
|
||||
documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]
|
||||
texts = ranker._prepare_texts_to_embed(documents=documents)
|
||||
|
||||
assert texts == [
|
||||
"test doc: meta_value 0\ndocument number 0 end doc.",
|
||||
"test doc: meta_value 1\ndocument number 1 end doc.",
|
||||
"test doc: meta_value 2\ndocument number 2 end doc.",
|
||||
"test doc: meta_value 3\ndocument number 3 end doc.",
|
||||
"test doc: meta_value 4\ndocument number 4 end doc.",
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_encode_text(self, similarity):
|
||||
"""
|
||||
Test addition of suffix and prefix to the query and documents when creating embeddings.
|
||||
"""
|
||||
ranker = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2",
|
||||
similarity=similarity,
|
||||
query_prefix="test query: ",
|
||||
query_suffix=" end query.",
|
||||
document_prefix="test doc: ",
|
||||
document_suffix=" end doc.",
|
||||
meta_fields_to_embed=["meta_field"],
|
||||
embedding_separator="\n",
|
||||
)
|
||||
query = "query"
|
||||
documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]
|
||||
ranker.model = MagicMock()
|
||||
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
|
||||
ranker.run(query=query, documents=documents)
|
||||
|
||||
assert ranker.model.encode.call_count == 2
|
||||
ranker.model.assert_has_calls(
|
||||
[
|
||||
call.encode(
|
||||
[
|
||||
"test doc: meta_value 0\ndocument number 0 end doc.",
|
||||
"test doc: meta_value 1\ndocument number 1 end doc.",
|
||||
"test doc: meta_value 2\ndocument number 2 end doc.",
|
||||
"test doc: meta_value 3\ndocument number 3 end doc.",
|
||||
"test doc: meta_value 4\ndocument number 4 end doc.",
|
||||
],
|
||||
convert_to_tensor=True,
|
||||
),
|
||||
call.encode(["test query: query end query."], convert_to_tensor=True),
|
||||
]
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_run_greedy_diversity_order(self, similarity):
|
||||
"""
|
||||
Tests that the given list of documents is ordered to maximize diversity.
|
||||
"""
|
||||
ranker = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity
|
||||
)
|
||||
query = "city"
|
||||
documents = [Document(content="Eiffel Tower"), Document(content="Berlin"), Document(content="Bananas")]
|
||||
ranker.model = MagicMock()
|
||||
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
|
||||
|
||||
ranked_docs = ranker._greedy_diversity_order(query=query, documents=documents)
|
||||
ranked_text = " ".join([doc.content for doc in ranked_docs])
|
||||
|
||||
assert ranked_text == "Berlin Eiffel Tower Bananas"
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_run(self, similarity):
|
||||
"""
|
||||
Tests that run method returns documents in the correct order
|
||||
"""
|
||||
ranker = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity
|
||||
)
|
||||
ranker.warm_up()
|
||||
query = "city"
|
||||
documents = [
|
||||
Document(content="France"),
|
||||
Document(content="Germany"),
|
||||
Document(content="Eiffel Tower"),
|
||||
Document(content="Berlin"),
|
||||
Document(content="Bananas"),
|
||||
Document(content="Silicon Valley"),
|
||||
Document(content="Brandenburg Gate"),
|
||||
]
|
||||
result = ranker.run(query=query, documents=documents)
|
||||
ranked_docs = result["documents"]
|
||||
ranked_order = ", ".join([doc.content for doc in ranked_docs])
|
||||
expected_order = "Berlin, Bananas, Eiffel Tower, Silicon Valley, France, Brandenburg Gate, Germany"
|
||||
|
||||
assert ranked_order == expected_order
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_run_real_world_use_case(self, similarity):
|
||||
ranker = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity
|
||||
)
|
||||
ranker.warm_up()
|
||||
query = "What are the reasons for long-standing animosities between Russia and Poland?"
|
||||
|
||||
doc1 = Document(
|
||||
"One of the earliest known events in Russian-Polish history dates back to 981, when the Grand Prince of Kiev , "
|
||||
"Vladimir Svyatoslavich , seized the Cherven Cities from the Duchy of Poland . The relationship between two by "
|
||||
"that time was mostly close and cordial, as there had been no serious wars between both. In 966, Poland "
|
||||
"accepted Christianity from Rome while Kievan Rus' —the ancestor of Russia, Ukraine and Belarus—was "
|
||||
"Christianized by Constantinople. In 1054, the internal Christian divide formally split the Church into "
|
||||
"the Catholic and Orthodox branches separating the Poles from the Eastern Slavs."
|
||||
)
|
||||
doc2 = Document(
|
||||
"Since the fall of the Soviet Union , with Lithuania , Ukraine and Belarus regaining independence, the "
|
||||
"Polish Russian border has mostly been replaced by borders with the respective countries, but there still "
|
||||
"is a 210 km long border between Poland and the Kaliningrad Oblast"
|
||||
)
|
||||
doc3 = Document(
|
||||
"As part of Poland's plans to become fully energy independent from Russia within the next years, Piotr "
|
||||
"Wozniak, president of state-controlled oil and gas company PGNiG , stated in February 2019: 'The strategy of "
|
||||
"the company is just to forget about Eastern suppliers and especially about Gazprom .'[53] In 2020, the "
|
||||
"Stockholm Arbitrary Tribunal ruled that PGNiG's long-term contract gas price with Gazprom linked to oil prices "
|
||||
"should be changed to approximate the Western European gas market price, backdated to 1 November 2014 when "
|
||||
"PGNiG requested a price review under the contract. Gazprom had to refund about $1.5 billion to PGNiG."
|
||||
)
|
||||
doc4 = Document(
|
||||
"Both Poland and Russia had accused each other for their historical revisionism . Russia has repeatedly "
|
||||
"accused Poland for not honoring Soviet Red Army soldiers fallen in World War II for Poland, notably in "
|
||||
"2017, in which Poland was thought on 'attempting to impose its own version of history' after Moscow was "
|
||||
"not allowed to join an international effort to renovate a World War II museum at Sobibór , site of a "
|
||||
"notorious Sobibor extermination camp."
|
||||
)
|
||||
doc5 = Document(
|
||||
"President of Russia Vladimir Putin and Prime Minister of Poland Leszek Miller in 2002 Modern Polish Russian "
|
||||
"relations begin with the fall of communism in1989 in Poland ( Solidarity and the Polish Round Table "
|
||||
"Agreement ) and 1991 in Russia ( dissolution of the Soviet Union ). With a new democratic government after "
|
||||
"the 1989 elections , Poland regained full sovereignty, [2] and what was the Soviet Union, became 15 newly "
|
||||
"independent states , including the Russian Federation . Relations between modern Poland and Russia suffer "
|
||||
"from constant ups and downs."
|
||||
)
|
||||
doc6 = Document(
|
||||
"Soviet influence in Poland finally ended with the Round Table Agreement of 1989 guaranteeing free elections "
|
||||
"in Poland, the Revolutions of 1989 against Soviet-sponsored Communist governments in the Eastern Block , and "
|
||||
"finally the formal dissolution of the Warsaw Pact."
|
||||
)
|
||||
doc7 = Document(
|
||||
"Dmitry Medvedev and then Polish Prime Minister Donald Tusk , 6 December 2010 BBC News reported that one of "
|
||||
"the main effects of the 2010 Polish Air Force Tu-154 crash would be the impact it has on Russian-Polish "
|
||||
"relations. [38] It was thought if the inquiry into the crash were not transparent, it would increase "
|
||||
"suspicions toward Russia in Poland."
|
||||
)
|
||||
doc8 = Document(
|
||||
"Soviet control over the Polish People's Republic lessened after Stalin's death and Gomulka's Thaw , and "
|
||||
"ceased completely after the fall of the communist government in Poland in late 1989, although the "
|
||||
"Soviet-Russian Northern Group of Forces did not leave Polish soil until 1993. The continuing Soviet military "
|
||||
"presence allowed the Soviet Union to heavily influence Polish politics."
|
||||
)
|
||||
|
||||
documents = [doc1, doc2, doc3, doc4, doc5, doc6, doc7, doc8]
|
||||
result = ranker.run(query=query, documents=documents)
|
||||
expected_order = [doc5, doc7, doc3, doc1, doc4, doc2, doc6, doc8]
|
||||
expected_content = " ".join([doc.content or "" for doc in expected_order])
|
||||
result_content = " ".join([doc.content or "" for doc in result["documents"]])
|
||||
|
||||
# Check the order of ranked documents by comparing the content of the ranked documents
|
||||
assert result_content == expected_content
|
||||
Loading…
x
Reference in New Issue
Block a user