feat: adding Maximum Margin Relevance Ranker (#8554)

* initial import

* linting

* adding MMR tests

* adding release notes

* fixing tests

* adding linting ignore to cross-encoder ranker

* update docstring

* refactoring

* making strategy Optional instead of Literal

* wip: adding unit tests

* refactoring MMR algorithm

* refactoring tests

* cleaning up and updating tests

* adding empty line between license + code

* bug in tests

* using Enum for strategy and similarity metric

* adding more tests

* adding empty line between license + code

* removing run time params

* PR comments

* PR comments

* fixing

* fixing serialisation

* fixing serialisation tests

* Update haystack/components/rankers/sentence_transformers_diversity.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* Update haystack/components/rankers/sentence_transformers_diversity.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* Update haystack/components/rankers/sentence_transformers_diversity.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* Update haystack/components/rankers/sentence_transformers_diversity.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* Update haystack/components/rankers/sentence_transformers_diversity.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* Update haystack/components/rankers/sentence_transformers_diversity.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* Update haystack/components/rankers/sentence_transformers_diversity.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* fixing tests

* PR comments

* PR comments

* PR comments

* PR comments

---------

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
David S. Batista 2024-11-22 15:58:45 +01:00 committed by GitHub
parent a8eeb2024f
commit b5a2fad642
4 changed files with 340 additions and 77 deletions
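
For context, here is a minimal usage sketch of the strategy added in this commit, based only on the parameters visible in the diff below (`strategy`, `lambda_threshold`, and the per-call `lambda_threshold` override in `run()`); the query and documents are illustrative, borrowed from the unit tests.

```python
from haystack import Document
from haystack.components.rankers import SentenceTransformersDiversityRanker

# New in this commit: strategy="maximum_margin_relevance" and lambda_threshold,
# which balances relevance to the query (closer to 1) against diversity (closer to 0).
ranker = SentenceTransformersDiversityRanker(
    model="sentence-transformers/all-MiniLM-L6-v2",
    similarity="cosine",
    strategy="maximum_margin_relevance",
    lambda_threshold=0.7,
)
ranker.warm_up()

docs = [Document(content="Eiffel Tower"), Document(content="Berlin"), Document(content="Bananas")]

# run() also accepts a per-call lambda_threshold override when the MMR strategy is active.
result = ranker.run(query="city", documents=docs, lambda_threshold=0.5)
ranked_docs = result["documents"]
```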

haystack/components/rankers/sentence_transformers_diversity.py

@ -2,7 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Literal, Optional
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.lazy_imports import LazyImport
@ -16,24 +17,91 @@ with LazyImport(message="Run 'pip install \"sentence-transformers>=3.0.0\"'") as
from sentence_transformers import SentenceTransformer
class DiversityRankingStrategy(Enum):
"""
The strategy to use for diversity ranking.
"""
GREEDY_DIVERSITY_ORDER = "greedy_diversity_order"
MAXIMUM_MARGIN_RELEVANCE = "maximum_margin_relevance"
def __str__(self) -> str:
"""
Convert a Strategy enum to a string.
"""
return self.value
@staticmethod
def from_str(string: str) -> "DiversityRankingStrategy":
"""
Convert a string to a Strategy enum.
"""
enum_map = {e.value: e for e in DiversityRankingStrategy}
strategy = enum_map.get(string)
if strategy is None:
msg = f"Unknown strategy '{string}'. Supported strategies are: {list(enum_map.keys())}"
raise ValueError(msg)
return strategy
class DiversityRankingSimilarity(Enum):
"""
The similarity metric to use for comparing embeddings.
"""
DOT_PRODUCT = "dot_product"
COSINE = "cosine"
def __str__(self) -> str:
"""
Convert a Similarity enum to a string.
"""
return self.value
@staticmethod
def from_str(string: str) -> "DiversityRankingSimilarity":
"""
Convert a string to a Similarity enum.
"""
enum_map = {e.value: e for e in DiversityRankingSimilarity}
similarity = enum_map.get(string)
if similarity is None:
msg = f"Unknown similarity metric '{string}'. Supported metrics are: {list(enum_map.keys())}"
raise ValueError(msg)
return similarity
@component
class SentenceTransformersDiversityRanker:
"""
A Diversity Ranker based on Sentence Transformers.
Implements a document ranking algorithm that orders documents in such a way as to maximize the overall diversity
of the documents.
Applies a document ranking algorithm based on one of the two strategies:
This component provides functionality to rank a list of documents based on their similarity with respect to the
query to maximize the overall diversity. It uses a pre-trained Sentence Transformers model to embed the query and
the Documents.
1. Greedy Diversity Order:
Usage example:
Implements a document ranking algorithm that orders documents in a way that maximizes the overall diversity
of the documents based on their similarity to the query.
It uses a pre-trained Sentence Transformers model to embed the query and
the documents.
2. Maximum Margin Relevance:
Implements a document ranking algorithm that orders documents based on their Maximum Margin Relevance (MMR)
scores.
MMR scores are calculated for each document based on their relevance to the query and diversity from already
selected documents. The algorithm iteratively selects documents based on their MMR scores, balancing between
relevance to the query and diversity from already selected documents. The 'lambda_threshold' controls the
trade-off between relevance and diversity.
### Usage example
```python
from haystack import Document
from haystack.components.rankers import SentenceTransformersDiversityRanker
ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine")
ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine", strategy="greedy_diversity_order")
ranker.warm_up()
docs = [Document(content="Paris"), Document(content="Berlin")]
@ -41,7 +109,7 @@ class SentenceTransformersDiversityRanker:
output = ranker.run(query=query, documents=docs)
docs = output["documents"]
```
"""
""" # noqa: E501
def __init__(
self,
@ -49,14 +117,16 @@ class SentenceTransformersDiversityRanker:
top_k: int = 10,
device: Optional[ComponentDevice] = None,
token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
similarity: Literal["dot_product", "cosine"] = "cosine",
similarity: Union[str, DiversityRankingSimilarity] = "cosine",
query_prefix: str = "",
query_suffix: str = "",
document_prefix: str = "",
document_suffix: str = "",
meta_fields_to_embed: Optional[List[str]] = None,
embedding_separator: str = "\n",
):
strategy: Union[str, DiversityRankingStrategy] = "greedy_diversity_order",
lambda_threshold: float = 0.5,
): # pylint: disable=too-many-positional-arguments
"""
Initialize a SentenceTransformersDiversityRanker.
@ -78,6 +148,10 @@ class SentenceTransformersDiversityRanker:
:param document_suffix: A string to add to the end of each Document text before ranking.
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
:param strategy: The strategy to use for diversity ranking. Can be either "greedy_diversity_order" or
"maximum_margin_relevance".
:param lambda_threshold: The trade-off parameter between relevance and diversity. Only used when strategy is
"maximum_margin_relevance".
"""
torch_and_sentence_transformers_import.check()
@ -88,15 +162,16 @@ class SentenceTransformersDiversityRanker:
self.device = ComponentDevice.resolve_device(device)
self.token = token
self.model = None
if similarity not in ["dot_product", "cosine"]:
raise ValueError(f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}.")
self.similarity = similarity
self.similarity = DiversityRankingSimilarity.from_str(similarity) if isinstance(similarity, str) else similarity
self.query_prefix = query_prefix
self.document_prefix = document_prefix
self.query_suffix = query_suffix
self.document_suffix = document_suffix
self.meta_fields_to_embed = meta_fields_to_embed or []
self.embedding_separator = embedding_separator
self.strategy = DiversityRankingStrategy.from_str(strategy) if isinstance(strategy, str) else strategy
self.lambda_threshold = lambda_threshold or 0.5
self._check_lambda_threshold(self.lambda_threshold, self.strategy)
def warm_up(self):
"""
@ -119,16 +194,18 @@ class SentenceTransformersDiversityRanker:
return default_to_dict(
self,
model=self.model_name_or_path,
top_k=self.top_k,
device=self.device.to_dict(),
token=self.token.to_dict() if self.token else None,
top_k=self.top_k,
similarity=self.similarity,
similarity=str(self.similarity),
query_prefix=self.query_prefix,
document_prefix=self.document_prefix,
query_suffix=self.query_suffix,
document_prefix=self.document_prefix,
document_suffix=self.document_suffix,
meta_fields_to_embed=self.meta_fields_to_embed,
embedding_separator=self.embedding_separator,
strategy=str(self.strategy),
lambda_threshold=self.lambda_threshold,
)
@classmethod
@ -182,14 +259,7 @@ class SentenceTransformersDiversityRanker:
"""
texts_to_embed = self._prepare_texts_to_embed(documents)
# Calculate embeddings
doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True) # type: ignore[attr-defined]
query_embedding = self.model.encode([self.query_prefix + query + self.query_suffix], convert_to_tensor=True) # type: ignore[attr-defined]
# Normalize embeddings to unit length for computing cosine similarity
if self.similarity == "cosine":
doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)
doc_embeddings, query_embedding = self._embed_and_normalize(query, texts_to_embed)
n = len(documents)
selected: List[int] = []
@ -218,14 +288,84 @@ class SentenceTransformersDiversityRanker:
return ranked_docs
def _embed_and_normalize(self, query, texts_to_embed):
# Calculate embeddings
doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True) # type: ignore[attr-defined]
query_embedding = self.model.encode([self.query_prefix + query + self.query_suffix], convert_to_tensor=True) # type: ignore[attr-defined]
# Normalize embeddings to unit length for computing cosine similarity
if self.similarity == DiversityRankingSimilarity.COSINE:
doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)
return doc_embeddings, query_embedding
def _maximum_margin_relevance(
self, query: str, documents: List[Document], lambda_threshold: float, top_k: int
) -> List[Document]:
"""
Orders the given list of documents according to the Maximum Margin Relevance (MMR) scores.
MMR scores are calculated for each document based on their relevance to the query and diversity from already
selected documents.
The algorithm iteratively selects documents based on their MMR scores, balancing between relevance to the query
and diversity from already selected documents. The 'lambda_threshold' controls the trade-off between relevance
and diversity.
A closer value to 0 favors diversity, while a closer value to 1 favors relevance to the query.
See : "The Use of MMR, Diversity-Based Reranking for Reordering Documents and Producing Summaries"
https://www.cs.cmu.edu/~jgc/publication/The_Use_MMR_Diversity_Based_LTMIR_1998.pdf
"""
texts_to_embed = self._prepare_texts_to_embed(documents)
doc_embeddings, query_embedding = self._embed_and_normalize(query, texts_to_embed)
top_k = top_k if top_k else len(documents)
selected: List[int] = []
query_similarities_as_tensor = query_embedding @ doc_embeddings.T
query_similarities = query_similarities_as_tensor.reshape(-1)
idx = int(torch.argmax(query_similarities))
selected.append(idx)
while len(selected) < top_k:
best_idx = None
best_score = -float("inf")
for idx, _ in enumerate(documents):
if idx in selected:
continue
relevance_score = query_similarities[idx]
diversity_score = max(doc_embeddings[idx] @ doc_embeddings[j].T for j in selected)
mmr_score = lambda_threshold * relevance_score - (1 - lambda_threshold) * diversity_score
if mmr_score > best_score:
best_score = mmr_score
best_idx = idx
if best_idx is None:
raise ValueError("No best document found, check if the documents list contains any documents.")
selected.append(best_idx)
return [documents[i] for i in selected]
@staticmethod
def _check_lambda_threshold(lambda_threshold: float, strategy: DiversityRankingStrategy):
if (strategy == DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE) and not 0 <= lambda_threshold <= 1:
raise ValueError(f"lambda_threshold must be between 0 and 1, but got {lambda_threshold}.")
@component.output_types(documents=List[Document])
def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
def run(
self,
query: str,
documents: List[Document],
top_k: Optional[int] = None,
lambda_threshold: Optional[float] = None,
) -> Dict[str, List[Document]]:
"""
Rank the documents based on their diversity.
:param query: The search query.
:param documents: List of Document objects to be ranked.
:param top_k: Optional. An integer to override the top_k set during initialization.
:param lambda_threshold: Override the trade-off parameter between relevance and diversity. Only used when
strategy is "maximum_margin_relevance".
:returns: A dictionary with the following key:
- `documents`: List of Document objects that have been selected based on the diversity ranking.
@ -245,9 +385,17 @@ class SentenceTransformersDiversityRanker:
if top_k is None:
top_k = self.top_k
elif top_k <= 0:
raise ValueError(f"top_k must be > 0, but got {top_k}")
elif not 0 < top_k <= len(documents):
raise ValueError(f"top_k must be between 1 and {len(documents)}, but got {top_k}")
diversity_sorted = self._greedy_diversity_order(query=query, documents=documents)
if self.strategy == DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE:
if lambda_threshold is None:
lambda_threshold = self.lambda_threshold
self._check_lambda_threshold(lambda_threshold, self.strategy)
re_ranked_docs = self._maximum_margin_relevance(
query=query, documents=documents, lambda_threshold=lambda_threshold, top_k=top_k
)
else:
re_ranked_docs = self._greedy_diversity_order(query=query, documents=documents)
return {"documents": diversity_sorted[:top_k]}
return {"documents": re_ranked_docs[:top_k]}


@ -43,7 +43,7 @@ class TransformersSimilarityRanker:
```
"""
def __init__( # noqa: PLR0913
def __init__( # noqa: PLR0913, pylint: disable=too-many-positional-arguments
self,
model: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2",
device: Optional[ComponentDevice] = None,
@ -201,7 +201,7 @@ class TransformersSimilarityRanker:
return default_from_dict(cls, data)
@component.output_types(documents=List[Document])
def run(
def run( # pylint: disable=too-many-positional-arguments
self,
query: str,
documents: List[Document],


@ -0,0 +1,4 @@
---
enhancements:
- |
Added the Maximum Margin Relevance (MMR) strategy to the `SentenceTransformersDiversityRanker`. MMR scores are calculated for each document based on their relevance to the query and diversity from already selected documents.
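
Because `to_dict()` now also serializes `strategy` and `lambda_threshold` (with `similarity` and `strategy` stored as strings via `str(...)`), a pipeline containing the ranker round-trips with the new parameters. A small sketch mirroring the new `test_pipeline_serialise_deserialise` test; the parameter values are illustrative.

```python
from haystack import Pipeline
from haystack.components.rankers import SentenceTransformersDiversityRanker

ranker = SentenceTransformersDiversityRanker(
    model="sentence-transformers/all-MiniLM-L6-v2",
    similarity="cosine",
    strategy="maximum_margin_relevance",
    lambda_threshold=0.3,
)

pipe = Pipeline()
pipe.add_component("ranker", ranker)

# The YAML round-trip preserves strategy and lambda_threshold.
restored = Pipeline.loads(pipe.dumps())
assert restored == pipe
```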


@ -6,8 +6,12 @@ from unittest.mock import MagicMock, call, patch
import pytest
import torch
from haystack import Document
from haystack import Document, Pipeline
from haystack.components.rankers import SentenceTransformersDiversityRanker
from haystack.components.rankers.sentence_transformers_diversity import (
DiversityRankingSimilarity,
DiversityRankingStrategy,
)
from haystack.utils import ComponentDevice
from haystack.utils.auth import Secret
@ -27,7 +31,7 @@ class TestSentenceTransformersDiversityRanker:
assert component.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
assert component.top_k == 10
assert component.device == ComponentDevice.resolve_device(None)
assert component.similarity == "cosine"
assert component.similarity == DiversityRankingSimilarity.COSINE
assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
assert component.query_prefix == ""
assert component.document_prefix == ""
@ -36,7 +40,7 @@ class TestSentenceTransformersDiversityRanker:
assert component.meta_fields_to_embed == []
assert component.embedding_separator == "\n"
def test_init_with_custom_init_parameters(self):
def test_init_with_custom_parameters(self):
component = SentenceTransformersDiversityRanker(
model="sentence-transformers/msmarco-distilbert-base-v4",
top_k=5,
@ -53,7 +57,7 @@ class TestSentenceTransformersDiversityRanker:
assert component.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4"
assert component.top_k == 5
assert component.device == ComponentDevice.from_str("cuda:0")
assert component.similarity == "dot_product"
assert component.similarity == DiversityRankingSimilarity.DOT_PRODUCT
assert component.token == Secret.from_token("fake-api-token")
assert component.query_prefix == "query:"
assert component.document_prefix == "document:"
@ -65,22 +69,26 @@ class TestSentenceTransformersDiversityRanker:
def test_to_dict(self):
component = SentenceTransformersDiversityRanker()
data = component.to_dict()
assert data == {
"type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker",
"init_parameters": {
"model": "sentence-transformers/all-MiniLM-L6-v2",
"top_k": 10,
"device": ComponentDevice.resolve_device(None).to_dict(),
"similarity": "cosine",
"token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
"query_prefix": "",
"document_prefix": "",
"query_suffix": "",
"document_suffix": "",
"meta_fields_to_embed": [],
"embedding_separator": "\n",
},
assert (
data["type"]
== "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker"
)
assert data["init_parameters"]["model"] == "sentence-transformers/all-MiniLM-L6-v2"
assert data["init_parameters"]["top_k"] == 10
assert data["init_parameters"]["device"] == ComponentDevice.resolve_device(None).to_dict()
assert data["init_parameters"]["similarity"] == "cosine"
assert data["init_parameters"]["token"] == {
"env_vars": ["HF_API_TOKEN", "HF_TOKEN"],
"strict": False,
"type": "env_var",
}
assert data["init_parameters"]["query_prefix"] == ""
assert data["init_parameters"]["document_prefix"] == ""
assert data["init_parameters"]["query_suffix"] == ""
assert data["init_parameters"]["document_suffix"] == ""
assert data["init_parameters"]["meta_fields_to_embed"] == []
assert data["init_parameters"]["embedding_separator"] == "\n"
assert data["init_parameters"]["strategy"] == "greedy_diversity_order"
def test_from_dict(self):
data = {
@ -104,7 +112,7 @@ class TestSentenceTransformersDiversityRanker:
assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
assert ranker.top_k == 10
assert ranker.device == ComponentDevice.resolve_device(None)
assert ranker.similarity == "cosine"
assert ranker.similarity == DiversityRankingSimilarity.COSINE
assert ranker.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
assert ranker.query_prefix == ""
assert ranker.document_prefix == ""
@ -135,7 +143,7 @@ class TestSentenceTransformersDiversityRanker:
assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
assert ranker.top_k == 10
assert ranker.device == ComponentDevice.resolve_device(None)
assert ranker.similarity == "cosine"
assert ranker.similarity == DiversityRankingSimilarity.COSINE
assert ranker.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
assert ranker.query_prefix == ""
assert ranker.document_prefix == ""
@ -154,7 +162,7 @@ class TestSentenceTransformersDiversityRanker:
assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
assert ranker.top_k == 10
assert ranker.device == ComponentDevice.resolve_device(None)
assert ranker.similarity == "cosine"
assert ranker.similarity == DiversityRankingSimilarity.COSINE
assert ranker.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
assert ranker.query_prefix == ""
assert ranker.document_prefix == ""
@ -163,7 +171,7 @@ class TestSentenceTransformersDiversityRanker:
assert ranker.meta_fields_to_embed == []
assert ranker.embedding_separator == "\n"
def test_to_dict_with_custom_init_parameters(self):
def test_to_dict_with_custom_parameters(self):
component = SentenceTransformersDiversityRanker(
model="sentence-transformers/msmarco-distilbert-base-v4",
top_k=5,
@ -178,22 +186,23 @@ class TestSentenceTransformersDiversityRanker:
embedding_separator="--",
)
data = component.to_dict()
assert data == {
"type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker",
"init_parameters": {
"model": "sentence-transformers/msmarco-distilbert-base-v4",
"top_k": 5,
"device": ComponentDevice.from_str("cuda:0").to_dict(),
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
"similarity": "dot_product",
"query_prefix": "query:",
"document_prefix": "document:",
"query_suffix": "query suffix",
"document_suffix": "document suffix",
"meta_fields_to_embed": ["meta_field"],
"embedding_separator": "--",
},
}
assert (
data["type"]
== "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker"
)
assert data["init_parameters"]["model"] == "sentence-transformers/msmarco-distilbert-base-v4"
assert data["init_parameters"]["top_k"] == 5
assert data["init_parameters"]["device"] == ComponentDevice.from_str("cuda:0").to_dict()
assert data["init_parameters"]["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
assert data["init_parameters"]["similarity"] == "dot_product"
assert data["init_parameters"]["query_prefix"] == "query:"
assert data["init_parameters"]["document_prefix"] == "document:"
assert data["init_parameters"]["query_suffix"] == "query suffix"
assert data["init_parameters"]["document_suffix"] == "document suffix"
assert data["init_parameters"]["meta_fields_to_embed"] == ["meta_field"]
assert data["init_parameters"]["embedding_separator"] == "--"
assert data["init_parameters"]["strategy"] == "greedy_diversity_order"
def test_from_dict_with_custom_init_parameters(self):
data = {
@ -217,7 +226,7 @@ class TestSentenceTransformersDiversityRanker:
assert ranker.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4"
assert ranker.top_k == 5
assert ranker.device == ComponentDevice.from_str("cuda:0")
assert ranker.similarity == "dot_product"
assert ranker.similarity == DiversityRankingSimilarity.DOT_PRODUCT
assert ranker.token == Secret.from_env_var("ENV_VAR", strict=False)
assert ranker.query_prefix == "query:"
assert ranker.document_prefix == "document:"
@ -226,16 +235,24 @@ class TestSentenceTransformersDiversityRanker:
assert ranker.meta_fields_to_embed == ["meta_field"]
assert ranker.embedding_separator == "--"
def test_run_incorrect_similarity(self):
def test_run_invalid_similarity(self):
"""
Tests that run method raises ValueError if similarity is incorrect
"""
similarity = "incorrect"
with pytest.raises(
ValueError, match=f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}."
):
with pytest.raises(ValueError, match=f"Unknown similarity metric"):
SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity)
def test_run_invalid_strategy(self):
"""
Tests that run method raises ValueError if strategy is incorrect
"""
strategy = "incorrect"
with pytest.raises(ValueError, match=f"Unknown strategy"):
SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine", strategy=strategy
)
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_without_warm_up(self, similarity):
"""
@ -362,7 +379,7 @@ class TestSentenceTransformersDiversityRanker:
query = "test"
documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")]
with pytest.raises(ValueError, match="top_k must be > 0, but got"):
with pytest.raises(ValueError, match="top_k must be between"):
ranker.run(query=query, documents=documents, top_k=-5)
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
@ -509,6 +526,52 @@ class TestSentenceTransformersDiversityRanker:
assert ranked_text == "Berlin Eiffel Tower Bananas"
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_maximum_margin_relevance(self, similarity):
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity
)
ranker.model = MagicMock()
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
query = "city"
documents = [Document(content="Eiffel Tower"), Document(content="Berlin"), Document(content="Bananas")]
ranker.model = MagicMock()
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
ranked_docs = ranker._maximum_margin_relevance(query=query, documents=documents, lambda_threshold=0, top_k=3)
ranked_text = " ".join([doc.content for doc in ranked_docs])
assert ranked_text == "Berlin Eiffel Tower Bananas"
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_maximum_margin_relevance_with_given_lambda_threshold(self, similarity):
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity
)
ranker.model = MagicMock()
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
query = "city"
documents = [Document(content="Eiffel Tower"), Document(content="Berlin"), Document(content="Bananas")]
ranker.model = MagicMock()
ranker.model.encode = MagicMock(side_effect=mock_encode_response)
ranked_docs = ranker._maximum_margin_relevance(query=query, documents=documents, lambda_threshold=1, top_k=3)
ranked_text = " ".join([doc.content for doc in ranked_docs])
assert ranked_text == "Berlin Eiffel Tower Bananas"
def test_pipeline_serialise_deserialise(self):
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine", top_k=5
)
pipe = Pipeline()
pipe.add_component("ranker", ranker)
pipe_serialized = pipe.dumps()
assert Pipeline.loads(pipe_serialized) == pipe
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run(self, similarity):
@ -607,3 +670,51 @@ class TestSentenceTransformersDiversityRanker:
# Check the order of ranked documents by comparing the content of the ranked documents
assert result_content == expected_content
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_with_maximum_margin_relevance_strategy(self, similarity):
query = "renewable energy sources"
docs = [
Document(content="18th-century French literature"),
Document(content="Solar power generation"),
Document(content="Ancient Egyptian hieroglyphics"),
Document(content="Wind turbine technology"),
Document(content="Baking sourdough bread"),
Document(content="Hydroelectric dam systems"),
Document(content="Geothermal energy extraction"),
Document(content="Biomass fuel production"),
]
ranker = SentenceTransformersDiversityRanker(
model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, strategy="maximum_margin_relevance"
)
ranker.warm_up()
# lambda_threshold=1, the most relevant document should be returned first
results = ranker.run(query=query, documents=docs, lambda_threshold=1, top_k=len(docs))
expected = [
"Solar power generation",
"Wind turbine technology",
"Geothermal energy extraction",
"Hydroelectric dam systems",
"Biomass fuel production",
"Ancient Egyptian hieroglyphics",
"Baking sourdough bread",
"18th-century French literature",
]
assert [doc.content for doc in results["documents"]] == expected
# lambda_threshold=0, after the most relevant one, diverse documents should be returned
results = ranker.run(query=query, documents=docs, lambda_threshold=0, top_k=len(docs))
expected = [
"Solar power generation",
"Ancient Egyptian hieroglyphics",
"Baking sourdough bread",
"18th-century French literature",
"Biomass fuel production",
"Hydroelectric dam systems",
"Geothermal energy extraction",
"Wind turbine technology",
]
assert [doc.content for doc in results["documents"]] == expected
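
The expected orderings in the integration test above follow from the MMR criterion in the paper cited in the docstring (Carbonell and Goldstein, 1998). With lambda = 1 the diversity term drops out and the result is a pure relevance ordering; with lambda = 0 only the first pick is driven by relevance and every later pick minimizes its maximum similarity to the documents already selected. In the paper's notation, with Q the query, R the candidate documents, S the already selected ones, and Sim_1/Sim_2 the query-document and document-document similarities:

```latex
\mathrm{MMR} = \operatorname*{arg\,max}_{D_i \in R \setminus S}
\left[ \lambda \, \mathrm{Sim}_1(D_i, Q) - (1 - \lambda) \max_{D_j \in S} \mathrm{Sim}_2(D_i, D_j) \right]
```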