feat: Add SentenceTransformersDiversityRanker (#7095)

* Add Diversity Ranker

* Update tests

* Add separate suffix, prefix params for query and documents; allow empty query

* Update docstrings

* Make changes based on review

* Add additional tests

* Add test for warm up

* Update release notes

---------

Co-authored-by: Sebastian Husch Lee <sjrl@users.noreply.github.com>
This commit is contained in:
Ashwin Mathur 2024-03-11 17:44:59 +05:30 committed by GitHub
parent 6239b60814
commit 38b3472bb2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 813 additions and 1 deletions

View File

@ -1,5 +1,11 @@
from haystack.components.rankers.lost_in_the_middle import LostInTheMiddleRanker
from haystack.components.rankers.meta_field import MetaFieldRanker
from haystack.components.rankers.sentence_transformers_diversity import SentenceTransformersDiversityRanker
from haystack.components.rankers.transformers_similarity import TransformersSimilarityRanker
__all__ = ["LostInTheMiddleRanker", "MetaFieldRanker", "TransformersSimilarityRanker"]
__all__ = [
"LostInTheMiddleRanker",
"MetaFieldRanker",
"SentenceTransformersDiversityRanker",
"TransformersSimilarityRanker",
]

View File

@ -0,0 +1,246 @@
from typing import Any, Dict, List, Literal, Optional
from haystack import ComponentError, Document, component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
logger = logging.getLogger(__name__)
with LazyImport(message="Run 'pip install \"sentence-transformers>=2.2.0\"'") as torch_and_sentence_transformers_import:
import torch
from sentence_transformers import SentenceTransformer
@component
class SentenceTransformersDiversityRanker:
"""
Implements a document ranking algorithm that orders documents in such a way as to maximize the overall diversity
of the documents.
This component provides functionality to rank a list of documents based on their similarity with respect to the
query to maximize the overall diversity. It uses a pre-trained Sentence Transformers model to embed the query and
the Documents.
Usage example:
```python
from haystack import Document
from haystack.components.rankers import SentenceTransformersDiversityRanker
ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine")
ranker.warm_up()
docs = [Document(content="Paris"), Document(content="Berlin")]
query = "What is the capital of germany?"
output = ranker.run(query=query, documents=docs)
docs = output["documents"]
```
"""
def __init__(
self,
model: str = "sentence-transformers/all-MiniLM-L6-v2",
top_k: int = 10,
device: Optional[ComponentDevice] = None,
token: Optional[Secret] = Secret.from_env_var("HF_API_TOKEN", strict=False),
similarity: Literal["dot_product", "cosine"] = "cosine",
query_prefix: str = "",
query_suffix: str = "",
document_prefix: str = "",
document_suffix: str = "",
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
):
"""
Initialize a SentenceTransformersDiversityRanker.
:param model: Local path or name of the model in Hugging Face's model hub,
such as `'sentence-transformers/all-MiniLM-L6-v2'`.
:param top_k: The maximum number of Documents to return per query.
:param device: The device on which the model is loaded. If `None`, the default device is automatically
selected.
:param token: The API token used to download private models from Hugging Face.
:param similarity: Similarity metric for comparing embeddings. Can be set to "dot_product" (default) or
"cosine".
:param query_prefix: A string to add to the beginning of the query text before ranking.
Can be used to prepend the text with an instruction, as required by some embedding models,
such as E5 and BGE.
:param query_suffix: A string to add to the end of the query text before ranking.
:param document_prefix: A string to add to the beginning of each Document text before ranking.
Can be used to prepend the text with an instruction, as required by some embedding models,
such as E5 and BGE.
:param document_suffix: A string to add to the end of each Document text before ranking.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
"""
torch_and_sentence_transformers_import.check()
self.model_name_or_path = model
if top_k is None or top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
self.top_k = top_k
self.device = ComponentDevice.resolve_device(device)
self.token = token
self.model = None
if similarity not in ["dot_product", "cosine"]:
raise ValueError(f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}.")
self.similarity = similarity
self.query_prefix = query_prefix
self.document_prefix = document_prefix
self.query_suffix = query_suffix
self.document_suffix = document_suffix
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
def warm_up(self):
"""
Initializes the component.
"""
if self.model is None:
self.model = SentenceTransformer(
model_name_or_path=self.model_name_or_path,
device=self.device.to_torch_str(),
use_auth_token=self.token.resolve_value() if self.token else None,
)
def to_dict(self) -> Dict[str, Any]:
"""
Serializes the component to a dictionary.
:returns:
Dictionary with serialized data.
"""
return default_to_dict(
self,
model=self.model_name_or_path,
device=self.device.to_dict(),
token=self.token.to_dict() if self.token else None,
top_k=self.top_k,
similarity=self.similarity,
query_prefix=self.query_prefix,
document_prefix=self.document_prefix,
query_suffix=self.query_suffix,
document_suffix=self.document_suffix,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersDiversityRanker":
"""
Deserializes the component from a dictionary.
:param data:
The dictionary to deserialize from.
:returns:
The deserialized component.
"""
serialized_device = data["init_parameters"]["device"]
data["init_parameters"]["device"] = ComponentDevice.from_dict(serialized_device)
deserialize_secrets_inplace(data["init_parameters"], keys=["token"])
return default_from_dict(cls, data)
def _prepare_texts_to_embed(self, documents: List[Document]) -> List[str]:
"""
Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.
"""
texts_to_embed = []
for doc in documents:
meta_values_to_embed = [
str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key]
]
text_to_embed = (
self.document_prefix
+ self.embedding_separator.join(meta_values_to_embed + [doc.content or ""])
+ self.document_suffix
)
texts_to_embed.append(text_to_embed)
return texts_to_embed
def _greedy_diversity_order(self, query: str, documents: List[Document]) -> List[Document]:
"""
Orders the given list of documents to maximize diversity.
The algorithm first calculates embeddings for each document and the query. It starts by selecting the document
that is semantically closest to the query. Then, for each remaining document, it selects the one that, on
average, is least similar to the already selected documents. This process continues until all documents are
selected, resulting in a list where each subsequent document contributes the most to the overall diversity of
the selected set.
:param query: The search query.
:param documents: The list of Document objects to be ranked.
:return: A list of documents ordered to maximize diversity.
"""
texts_to_embed = self._prepare_texts_to_embed(documents)
# Calculate embeddings
doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True) # type: ignore[attr-defined]
query_embedding = self.model.encode([self.query_prefix + query + self.query_suffix], convert_to_tensor=True) # type: ignore[attr-defined]
# Normalize embeddings to unit length for computing cosine similarity
if self.similarity == "cosine":
doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)
n = len(documents)
selected: List[int] = []
# Compute the similarity vector between the query and documents
query_doc_sim = query_embedding @ doc_embeddings.T
# Start with the document with the highest similarity to the query
selected.append(int(torch.argmax(query_doc_sim).item()))
selected_sum = doc_embeddings[selected[0]] / n
while len(selected) < n:
# Compute mean of dot products of all selected documents and all other documents
similarities = selected_sum @ doc_embeddings.T
# Mask documents that are already selected
similarities[selected] = torch.inf
# Select the document with the lowest total similarity score
index_unselected = int(torch.argmin(similarities).item())
selected.append(index_unselected)
# It's enough just to add to the selected vectors because dot product is distributive
# It's divided by n for numerical stability
selected_sum += doc_embeddings[index_unselected] / n
ranked_docs: List[Document] = [documents[i] for i in selected]
return ranked_docs
@component.output_types(documents=List[Document])
def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
"""
Rank the documents based on their diversity.
:param query: The search query.
:param documents: List of Document objects to be ranker.
:param top_k: Optional. An integer to override the top_k set during initialization.
:returns: A dictionary with the following key:
- `documents`: List of Document objects that have been selected based on the diversity ranking.
:raises ValueError: If the top_k value is less than or equal to 0.
"""
if not documents:
return {"documents": []}
if top_k is None:
top_k = self.top_k
elif top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
if self.model is None:
error_msg = (
"The component SentenceTransformersDiversityRanker wasn't warmed up. "
"Run 'warm_up()' before calling 'run()'."
)
raise ComponentError(error_msg)
diversity_sorted = self._greedy_diversity_order(query=query, documents=documents)
return {"documents": diversity_sorted[:top_k]}

View File

@ -0,0 +1,6 @@
---
features:
- |
Add `SentenceTransformersDiversityRanker`.
The Diversity Ranker orders documents in such a way as to maximize the overall diversity of the given documents.
The ranker leverages sentence-transformer models to calculate semantic embeddings for each document and the query.

View File

@ -0,0 +1,554 @@
from unittest.mock import MagicMock, call, patch
import pytest
import torch
from haystack import ComponentError, Document
from haystack.components.rankers import SentenceTransformersDiversityRanker
from haystack.utils import ComponentDevice
from haystack.utils.auth import Secret
def mock_encode_response(texts, **kwargs):
if texts == ["city"]:
return torch.tensor([[1.0, 1.0]])
elif texts == ["Eiffel Tower", "Berlin", "Bananas"]:
return torch.tensor([[1.0, 0.0], [0.8, 0.8], [0.0, 1.0]])
else:
return torch.tensor([[0.0, 1.0]] * len(texts))
class TestSentenceTransformersDiversityRanker:
def test_init(self):
component = SentenceTransformersDiversityRanker()
assert component.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
assert component.top_k == 10
assert component.device == ComponentDevice.resolve_device(None)
assert component.similarity == "cosine"
assert component.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert component.query_prefix == ""
assert component.document_prefix == ""
assert component.query_suffix == ""
assert component.document_suffix == ""
assert component.meta_fields_to_embed == []
assert component.embedding_separator == "\n"
def test_init_with_custom_init_parameters(self):
component = SentenceTransformersDiversityRanker(
model="sentence-transformers/msmarco-distilbert-base-v4",
top_k=5,
device=ComponentDevice.from_str("cuda:0"),
token=Secret.from_token("fake-api-token"),
similarity="dot_product",
query_prefix="query:",
document_prefix="document:",
query_suffix="query suffix",
document_suffix="document suffix",
meta_fields_to_embed=["meta_field"],
embedding_separator="--",
)
assert component.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4"
assert component.top_k == 5
assert component.device == ComponentDevice.from_str("cuda:0")
assert component.similarity == "dot_product"
assert component.token == Secret.from_token("fake-api-token")
assert component.query_prefix == "query:"
assert component.document_prefix == "document:"
assert component.query_suffix == "query suffix"
assert component.document_suffix == "document suffix"
assert component.meta_fields_to_embed == ["meta_field"]
assert component.embedding_separator == "--"
def test_to_dict(self):
component = SentenceTransformersDiversityRanker()
data = component.to_dict()
assert data == {
"type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker",
"init_parameters": {
"model": "sentence-transformers/all-MiniLM-L6-v2",
"top_k": 10,
"device": ComponentDevice.resolve_device(None).to_dict(),
"similarity": "cosine",
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"query_prefix": "",
"document_prefix": "",
"query_suffix": "",
"document_suffix": "",
"meta_fields_to_embed": [],
"embedding_separator": "\n",
},
}
def test_from_dict(self):
data = {
"type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker",
"init_parameters": {
"model": "sentence-transformers/all-MiniLM-L6-v2",
"top_k": 10,
"device": ComponentDevice.resolve_device(None).to_dict(),
"similarity": "cosine",
"token": {"env_vars": ["HF_API_TOKEN"], "strict": False, "type": "env_var"},
"query_prefix": "",
"document_prefix": "",
"query_suffix": "",
"document_suffix": "",
"meta_fields_to_embed": [],
"embedding_separator": "\n",
},
}
ranker = SentenceTransformersDiversityRanker.from_dict(data)
assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
assert ranker.top_k == 10
assert ranker.device == ComponentDevice.resolve_device(None)
assert ranker.similarity == "cosine"
assert ranker.token == Secret.from_env_var("HF_API_TOKEN", strict=False)
assert ranker.query_prefix == ""
assert ranker.document_prefix == ""
assert ranker.query_suffix == ""
assert ranker.document_suffix == ""
assert ranker.meta_fields_to_embed == []
assert ranker.embedding_separator == "\n"
def test_to_dict_with_custom_init_parameters(self):
component = SentenceTransformersDiversityRanker(
model="sentence-transformers/msmarco-distilbert-base-v4",
top_k=5,
device=ComponentDevice.from_str("cuda:0"),
token=Secret.from_env_var("ENV_VAR", strict=False),
similarity="dot_product",
query_prefix="query:",
document_prefix="document:",
query_suffix="query suffix",
document_suffix="document suffix",
meta_fields_to_embed=["meta_field"],
embedding_separator="--",
)
data = component.to_dict()
assert data == {
"type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker",
"init_parameters": {
"model": "sentence-transformers/msmarco-distilbert-base-v4",
"top_k": 5,
"device": ComponentDevice.from_str("cuda:0").to_dict(),
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"similarity": "dot_product",
"query_prefix": "query:",
"document_prefix": "document:",
"query_suffix": "query suffix",
"document_suffix": "document suffix",
"meta_fields_to_embed": ["meta_field"],
"embedding_separator": "--",
},
}
def test_from_dict_with_custom_init_parameters(self):
data = {
"type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker",
"init_parameters": {
"model": "sentence-transformers/msmarco-distilbert-base-v4",
"top_k": 5,
"device": ComponentDevice.from_str("cuda:0").to_dict(),
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"similarity": "dot_product",
"query_prefix": "query:",
"document_prefix": "document:",
"query_suffix": "query suffix",
"document_suffix": "document suffix",
"meta_fields_to_embed": ["meta_field"],
"embedding_separator": "--",
},
}
ranker = SentenceTransformersDiversityRanker.from_dict(data)
assert ranker.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4"
assert ranker.top_k == 5
assert ranker.device == ComponentDevice.from_str("cuda:0")
assert ranker.similarity == "dot_product"
assert ranker.token == Secret.from_env_var("ENV_VAR", strict=False)
assert ranker.query_prefix == "query:"
assert ranker.document_prefix == "document:"
assert ranker.query_suffix == "query suffix"
assert ranker.document_suffix == "document suffix"
assert ranker.meta_fields_to_embed == ["meta_field"]
assert ranker.embedding_separator == "--"
def test_run_incorrect_similarity(self):
"""
Tests that run method raises ValueError if similarity is incorrect
"""
similarity = "incorrect"
with pytest.raises(
ValueError, match=f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}."
):
SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity)
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_without_warm_up(self, similarity):
"""
Tests that run method raises ComponentError if model is not warmed up
"""
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", top_k=1, similarity=similarity
)
documents = [Document(content="doc1"), Document(content="doc2")]
error_msg = "The component SentenceTransformersDiversityRanker wasn't warmed up."
with pytest.raises(ComponentError, match=error_msg):
ranker.run(query="test query", documents=documents)
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_warm_up(self, similarity):
"""
Test that ranker loads the SentenceTransformer model correctly during warm up.
"""
mock_model_class = MagicMock()
mock_model_instance = MagicMock()
mock_model_class.return_value = mock_model_instance
with patch(
"haystack.components.rankers.sentence_transformers_diversity.SentenceTransformer", new=mock_model_class
):
ranker = SentenceTransformersDiversityRanker(model="mock_model_name", similarity=similarity)
assert ranker.model is None
ranker.warm_up()
mock_model_class.assert_called_once_with(
model_name_or_path="mock_model_name",
device=ComponentDevice.resolve_device(None).to_torch_str(),
use_auth_token=None,
)
assert ranker.model == mock_model_instance
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_empty_query(self, similarity):
"""
Test that ranker can be run with an empty query.
"""
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", top_k=3, similarity=similarity
)
ranker.model = MagicMock()
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
documents = [Document(content="doc1"), Document(content="doc2")]
result = ranker.run(query="", documents=documents)
ranked_docs = result["documents"]
assert isinstance(ranked_docs, list)
assert len(ranked_docs) == 2
assert all(isinstance(doc, Document) for doc in ranked_docs)
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_top_k(self, similarity):
"""
Test that run method returns the correct number of documents for different top_k values passed at
initialization and runtime.
"""
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=3
)
ranker.model = MagicMock()
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
query = "test query"
documents = [
Document(content="doc1"),
Document(content="doc2"),
Document(content="doc3"),
Document(content="doc4"),
]
result = ranker.run(query=query, documents=documents)
ranked_docs = result["documents"]
assert isinstance(ranked_docs, list)
assert len(ranked_docs) == 3
assert all(isinstance(doc, Document) for doc in ranked_docs)
# Passing a different top_k at runtime
result = ranker.run(query=query, documents=documents, top_k=2)
ranked_docs = result["documents"]
assert isinstance(ranked_docs, list)
assert len(ranked_docs) == 2
assert all(isinstance(doc, Document) for doc in ranked_docs)
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_negative_top_k_at_init(self, similarity):
"""
Tests that run method raises an error for negative top-k set at init.
"""
with pytest.raises(ValueError, match="top_k must be > 0, but got"):
SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=-5
)
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_top_k_is_none_at_init(self, similarity):
"""
Tests that run method raises an error for top-k set to None at init.
"""
with pytest.raises(ValueError, match="top_k must be > 0, but got"):
SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=None
)
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_negative_top_k(self, similarity):
"""
Tests that run method raises an error for negative top-k set at runtime.
"""
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=10
)
ranker.model = MagicMock()
query = "test"
documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")]
with pytest.raises(ValueError, match="top_k must be > 0, but got"):
ranker.run(query=query, documents=documents, top_k=-5)
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_top_k_is_none(self, similarity):
"""
Tests that run method returns the correct order of documents for top-k set to None.
"""
# Setting top_k to None is ignored during runtime, it should use top_k set at init.
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=2
)
ranker.model = MagicMock()
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
query = "test"
documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")]
result = ranker.run(query=query, documents=documents, top_k=None)
assert len(result["documents"]) == 2
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_no_documents_provided(self, similarity):
"""
Test that run method returns an empty list if no documents are supplied.
"""
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity
)
ranker.model = MagicMock()
query = "test query"
documents = []
results = ranker.run(query=query, documents=documents)
assert len(results["documents"]) == 0
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_with_less_documents_than_top_k(self, similarity):
"""
Tests that run method returns the correct number of documents for top_k values greater than number of documents.
"""
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, top_k=5
)
ranker.model = MagicMock()
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
query = "test"
documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")]
result = ranker.run(query=query, documents=documents)
assert len(result["documents"]) == 3
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_single_document_corner_case(self, similarity):
"""
Tests that run method returns the correct number of documents for a single document
"""
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity
)
ranker.model = MagicMock()
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
query = "test"
documents = [Document(content="doc1")]
result = ranker.run(query=query, documents=documents)
assert len(result["documents"]) == 1
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_prepare_texts_to_embed(self, similarity):
"""
Test creation of texts to embed from documents with meta fields, document prefix and suffix.
"""
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2",
similarity=similarity,
document_prefix="test doc: ",
document_suffix=" end doc.",
meta_fields_to_embed=["meta_field"],
embedding_separator="\n",
)
documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]
texts = ranker._prepare_texts_to_embed(documents=documents)
assert texts == [
"test doc: meta_value 0\ndocument number 0 end doc.",
"test doc: meta_value 1\ndocument number 1 end doc.",
"test doc: meta_value 2\ndocument number 2 end doc.",
"test doc: meta_value 3\ndocument number 3 end doc.",
"test doc: meta_value 4\ndocument number 4 end doc.",
]
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_encode_text(self, similarity):
"""
Test addition of suffix and prefix to the query and documents when creating embeddings.
"""
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2",
similarity=similarity,
query_prefix="test query: ",
query_suffix=" end query.",
document_prefix="test doc: ",
document_suffix=" end doc.",
meta_fields_to_embed=["meta_field"],
embedding_separator="\n",
)
query = "query"
documents = [Document(content=f"document number {i}", meta={"meta_field": f"meta_value {i}"}) for i in range(5)]
ranker.model = MagicMock()
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
ranker.run(query=query, documents=documents)
assert ranker.model.encode.call_count == 2
ranker.model.assert_has_calls(
[
call.encode(
[
"test doc: meta_value 0\ndocument number 0 end doc.",
"test doc: meta_value 1\ndocument number 1 end doc.",
"test doc: meta_value 2\ndocument number 2 end doc.",
"test doc: meta_value 3\ndocument number 3 end doc.",
"test doc: meta_value 4\ndocument number 4 end doc.",
],
convert_to_tensor=True,
),
call.encode(["test query: query end query."], convert_to_tensor=True),
]
)
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_greedy_diversity_order(self, similarity):
"""
Tests that the given list of documents is ordered to maximize diversity.
"""
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity
)
query = "city"
documents = [Document(content="Eiffel Tower"), Document(content="Berlin"), Document(content="Bananas")]
ranker.model = MagicMock()
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
ranked_docs = ranker._greedy_diversity_order(query=query, documents=documents)
ranked_text = " ".join([doc.content for doc in ranked_docs])
assert ranked_text == "Berlin Eiffel Tower Bananas"
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run(self, similarity):
"""
Tests that run method returns documents in the correct order
"""
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity
)
ranker.warm_up()
query = "city"
documents = [
Document(content="France"),
Document(content="Germany"),
Document(content="Eiffel Tower"),
Document(content="Berlin"),
Document(content="Bananas"),
Document(content="Silicon Valley"),
Document(content="Brandenburg Gate"),
]
result = ranker.run(query=query, documents=documents)
ranked_docs = result["documents"]
ranked_order = ", ".join([doc.content for doc in ranked_docs])
expected_order = "Berlin, Bananas, Eiffel Tower, Silicon Valley, France, Brandenburg Gate, Germany"
assert ranked_order == expected_order
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_real_world_use_case(self, similarity):
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity
)
ranker.warm_up()
query = "What are the reasons for long-standing animosities between Russia and Poland?"
doc1 = Document(
"One of the earliest known events in Russian-Polish history dates back to 981, when the Grand Prince of Kiev , "
"Vladimir Svyatoslavich , seized the Cherven Cities from the Duchy of Poland . The relationship between two by "
"that time was mostly close and cordial, as there had been no serious wars between both. In 966, Poland "
"accepted Christianity from Rome while Kievan Rus' —the ancestor of Russia, Ukraine and Belarus—was "
"Christianized by Constantinople. In 1054, the internal Christian divide formally split the Church into "
"the Catholic and Orthodox branches separating the Poles from the Eastern Slavs."
)
doc2 = Document(
"Since the fall of the Soviet Union , with Lithuania , Ukraine and Belarus regaining independence, the "
"Polish Russian border has mostly been replaced by borders with the respective countries, but there still "
"is a 210 km long border between Poland and the Kaliningrad Oblast"
)
doc3 = Document(
"As part of Poland's plans to become fully energy independent from Russia within the next years, Piotr "
"Wozniak, president of state-controlled oil and gas company PGNiG , stated in February 2019: 'The strategy of "
"the company is just to forget about Eastern suppliers and especially about Gazprom .'[53] In 2020, the "
"Stockholm Arbitrary Tribunal ruled that PGNiG's long-term contract gas price with Gazprom linked to oil prices "
"should be changed to approximate the Western European gas market price, backdated to 1 November 2014 when "
"PGNiG requested a price review under the contract. Gazprom had to refund about $1.5 billion to PGNiG."
)
doc4 = Document(
"Both Poland and Russia had accused each other for their historical revisionism . Russia has repeatedly "
"accused Poland for not honoring Soviet Red Army soldiers fallen in World War II for Poland, notably in "
"2017, in which Poland was thought on 'attempting to impose its own version of history' after Moscow was "
"not allowed to join an international effort to renovate a World War II museum at Sobibór , site of a "
"notorious Sobibor extermination camp."
)
doc5 = Document(
"President of Russia Vladimir Putin and Prime Minister of Poland Leszek Miller in 2002 Modern Polish Russian "
"relations begin with the fall of communism in1989 in Poland ( Solidarity and the Polish Round Table "
"Agreement ) and 1991 in Russia ( dissolution of the Soviet Union ). With a new democratic government after "
"the 1989 elections , Poland regained full sovereignty, [2] and what was the Soviet Union, became 15 newly "
"independent states , including the Russian Federation . Relations between modern Poland and Russia suffer "
"from constant ups and downs."
)
doc6 = Document(
"Soviet influence in Poland finally ended with the Round Table Agreement of 1989 guaranteeing free elections "
"in Poland, the Revolutions of 1989 against Soviet-sponsored Communist governments in the Eastern Block , and "
"finally the formal dissolution of the Warsaw Pact."
)
doc7 = Document(
"Dmitry Medvedev and then Polish Prime Minister Donald Tusk , 6 December 2010 BBC News reported that one of "
"the main effects of the 2010 Polish Air Force Tu-154 crash would be the impact it has on Russian-Polish "
"relations. [38] It was thought if the inquiry into the crash were not transparent, it would increase "
"suspicions toward Russia in Poland."
)
doc8 = Document(
"Soviet control over the Polish People's Republic lessened after Stalin's death and Gomulka's Thaw , and "
"ceased completely after the fall of the communist government in Poland in late 1989, although the "
"Soviet-Russian Northern Group of Forces did not leave Polish soil until 1993. The continuing Soviet military "
"presence allowed the Soviet Union to heavily influence Polish politics."
)
documents = [doc1, doc2, doc3, doc4, doc5, doc6, doc7, doc8]
result = ranker.run(query=query, documents=documents)
expected_order = [doc5, doc7, doc3, doc1, doc4, doc2, doc6, doc8]
expected_content = " ".join([doc.content or "" for doc in expected_order])
result_content = " ".join([doc.content or "" for doc in result["documents"]])
# Check the order of ranked documents by comparing the content of the ranked documents
assert result_content == expected_content