mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-24 21:48:52 +00:00
feat: adding Maximum Margin Relevance Ranker (#8554)
* initial import * linting * adding MRR tests * adding release notes * fixing tests * adding linting ignore to cross-encoder ranker * update docstring * refactoring * making strategy Optional instead of Literal * wip: adding unit tests * refactoring MMR algorithm * refactoring tests * cleaning up and updating tests * adding empty line between license + code * bug in tests * using Enum for strategy and similarity metric * adding more tests * adding empty line between license + code * removing run time params * PR comments * PR comments * fixing * fixing serialisation * fixing serialisation tests * Update haystack/components/rankers/sentence_transformers_diversity.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> * Update haystack/components/rankers/sentence_transformers_diversity.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> * Update haystack/components/rankers/sentence_transformers_diversity.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> * Update haystack/components/rankers/sentence_transformers_diversity.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> * Update haystack/components/rankers/sentence_transformers_diversity.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> * Update haystack/components/rankers/sentence_transformers_diversity.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> * Update haystack/components/rankers/sentence_transformers_diversity.py Co-authored-by: Daria Fokina <daria.fokina@deepset.ai> * fixing tests * PR comments * PR comments * PR comments * PR comments --------- Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
This commit is contained in:
parent
a8eeb2024f
commit
b5a2fad642
@ -2,7 +2,8 @@
|
||||
#
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from haystack import Document, component, default_from_dict, default_to_dict, logging
|
||||
from haystack.lazy_imports import LazyImport
|
||||
@ -16,24 +17,91 @@ with LazyImport(message="Run 'pip install \"sentence-transformers>=3.0.0\"'") as
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
|
||||
class DiversityRankingStrategy(Enum):
    """
    The strategy to use for diversity ranking.
    """

    GREEDY_DIVERSITY_ORDER = "greedy_diversity_order"
    MAXIMUM_MARGIN_RELEVANCE = "maximum_margin_relevance"

    def __str__(self) -> str:
        """
        Convert a Strategy enum to a string.
        """
        return self.value

    @staticmethod
    def from_str(string: str) -> "DiversityRankingStrategy":
        """
        Convert a string to a Strategy enum.
        """
        for member in DiversityRankingStrategy:
            if member.value == string:
                return member
        supported = [member.value for member in DiversityRankingStrategy]
        msg = f"Unknown strategy '{string}'. Supported strategies are: {supported}"
        raise ValueError(msg)
|
||||
|
||||
|
||||
class DiversityRankingSimilarity(Enum):
    """
    The similarity metric to use for comparing embeddings.
    """

    DOT_PRODUCT = "dot_product"
    COSINE = "cosine"

    def __str__(self) -> str:
        """
        Convert a Similarity enum to a string.
        """
        return self.value

    @staticmethod
    def from_str(string: str) -> "DiversityRankingSimilarity":
        """
        Convert a string to a Similarity enum.
        """
        for member in DiversityRankingSimilarity:
            if member.value == string:
                return member
        supported = [member.value for member in DiversityRankingSimilarity]
        msg = f"Unknown similarity metric '{string}'. Supported metrics are: {supported}"
        raise ValueError(msg)
|
||||
|
||||
|
||||
@component
|
||||
class SentenceTransformersDiversityRanker:
|
||||
"""
|
||||
A Diversity Ranker based on Sentence Transformers.
|
||||
|
||||
Implements a document ranking algorithm that orders documents in such a way as to maximize the overall diversity
|
||||
of the documents.
|
||||
Applies a document ranking algorithm based on one of the two strategies:
|
||||
|
||||
This component provides functionality to rank a list of documents based on their similarity with respect to the
|
||||
query to maximize the overall diversity. It uses a pre-trained Sentence Transformers model to embed the query and
|
||||
the Documents.
|
||||
1. Greedy Diversity Order:
|
||||
|
||||
Usage example:
|
||||
Implements a document ranking algorithm that orders documents in a way that maximizes the overall diversity
|
||||
of the documents based on their similarity to the query.
|
||||
|
||||
It uses a pre-trained Sentence Transformers model to embed the query and
|
||||
the documents.
|
||||
|
||||
2. Maximum Margin Relevance:
|
||||
|
||||
Implements a document ranking algorithm that orders documents based on their Maximum Margin Relevance (MMR)
|
||||
scores.
|
||||
|
||||
MMR scores are calculated for each document based on their relevance to the query and diversity from already
|
||||
selected documents. The algorithm iteratively selects documents based on their MMR scores, balancing between
|
||||
relevance to the query and diversity from already selected documents. The 'lambda_threshold' controls the
|
||||
trade-off between relevance and diversity.
|
||||
|
||||
### Usage example
|
||||
```python
|
||||
from haystack import Document
|
||||
from haystack.components.rankers import SentenceTransformersDiversityRanker
|
||||
|
||||
ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine")
|
||||
ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine", strategy="greedy_diversity_order")
|
||||
ranker.warm_up()
|
||||
|
||||
docs = [Document(content="Paris"), Document(content="Berlin")]
|
||||
@ -41,7 +109,7 @@ class SentenceTransformersDiversityRanker:
|
||||
output = ranker.run(query=query, documents=docs)
|
||||
docs = output["documents"]
|
||||
```
|
||||
"""
|
||||
""" # noqa: E501
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -49,14 +117,16 @@ class SentenceTransformersDiversityRanker:
|
||||
top_k: int = 10,
|
||||
device: Optional[ComponentDevice] = None,
|
||||
token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
|
||||
similarity: Literal["dot_product", "cosine"] = "cosine",
|
||||
similarity: Union[str, DiversityRankingSimilarity] = "cosine",
|
||||
query_prefix: str = "",
|
||||
query_suffix: str = "",
|
||||
document_prefix: str = "",
|
||||
document_suffix: str = "",
|
||||
meta_fields_to_embed: Optional[List[str]] = None,
|
||||
embedding_separator: str = "\n",
|
||||
):
|
||||
strategy: Union[str, DiversityRankingStrategy] = "greedy_diversity_order",
|
||||
lambda_threshold: float = 0.5,
|
||||
): # pylint: disable=too-many-positional-arguments
|
||||
"""
|
||||
Initialize a SentenceTransformersDiversityRanker.
|
||||
|
||||
@ -78,6 +148,10 @@ class SentenceTransformersDiversityRanker:
|
||||
:param document_suffix: A string to add to the end of each Document text before ranking.
|
||||
:param meta_fields_to_embed: List of meta fields that should be embedded along with the Document content.
|
||||
:param embedding_separator: Separator used to concatenate the meta fields to the Document content.
|
||||
:param strategy: The strategy to use for diversity ranking. Can be either "greedy_diversity_order" or
|
||||
"maximum_margin_relevance".
|
||||
:param lambda_threshold: The trade-off parameter between relevance and diversity. Only used when strategy is
|
||||
"maximum_margin_relevance".
|
||||
"""
|
||||
torch_and_sentence_transformers_import.check()
|
||||
|
||||
@ -88,15 +162,16 @@ class SentenceTransformersDiversityRanker:
|
||||
self.device = ComponentDevice.resolve_device(device)
|
||||
self.token = token
|
||||
self.model = None
|
||||
if similarity not in ["dot_product", "cosine"]:
|
||||
raise ValueError(f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}.")
|
||||
self.similarity = similarity
|
||||
self.similarity = DiversityRankingSimilarity.from_str(similarity) if isinstance(similarity, str) else similarity
|
||||
self.query_prefix = query_prefix
|
||||
self.document_prefix = document_prefix
|
||||
self.query_suffix = query_suffix
|
||||
self.document_suffix = document_suffix
|
||||
self.meta_fields_to_embed = meta_fields_to_embed or []
|
||||
self.embedding_separator = embedding_separator
|
||||
self.strategy = DiversityRankingStrategy.from_str(strategy) if isinstance(strategy, str) else strategy
|
||||
self.lambda_threshold = lambda_threshold or 0.5
|
||||
self._check_lambda_threshold(self.lambda_threshold, self.strategy)
|
||||
|
||||
def warm_up(self):
|
||||
"""
|
||||
@ -119,16 +194,18 @@ class SentenceTransformersDiversityRanker:
|
||||
return default_to_dict(
|
||||
self,
|
||||
model=self.model_name_or_path,
|
||||
top_k=self.top_k,
|
||||
device=self.device.to_dict(),
|
||||
token=self.token.to_dict() if self.token else None,
|
||||
top_k=self.top_k,
|
||||
similarity=self.similarity,
|
||||
similarity=str(self.similarity),
|
||||
query_prefix=self.query_prefix,
|
||||
document_prefix=self.document_prefix,
|
||||
query_suffix=self.query_suffix,
|
||||
document_prefix=self.document_prefix,
|
||||
document_suffix=self.document_suffix,
|
||||
meta_fields_to_embed=self.meta_fields_to_embed,
|
||||
embedding_separator=self.embedding_separator,
|
||||
strategy=str(self.strategy),
|
||||
lambda_threshold=self.lambda_threshold,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -182,14 +259,7 @@ class SentenceTransformersDiversityRanker:
|
||||
"""
|
||||
texts_to_embed = self._prepare_texts_to_embed(documents)
|
||||
|
||||
# Calculate embeddings
|
||||
doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True) # type: ignore[attr-defined]
|
||||
query_embedding = self.model.encode([self.query_prefix + query + self.query_suffix], convert_to_tensor=True) # type: ignore[attr-defined]
|
||||
|
||||
# Normalize embeddings to unit length for computing cosine similarity
|
||||
if self.similarity == "cosine":
|
||||
doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
|
||||
query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)
|
||||
doc_embeddings, query_embedding = self._embed_and_normalize(query, texts_to_embed)
|
||||
|
||||
n = len(documents)
|
||||
selected: List[int] = []
|
||||
@ -218,14 +288,84 @@ class SentenceTransformersDiversityRanker:
|
||||
|
||||
return ranked_docs
|
||||
|
||||
def _embed_and_normalize(self, query, texts_to_embed):
    """
    Embed the query and the document texts with the loaded model.

    When the similarity metric is cosine, both embedding sets are rescaled to
    unit length so that plain dot products later yield cosine similarities.
    """
    # Encode the documents and the (prefixed/suffixed) query.
    doc_embeddings = self.model.encode(texts_to_embed, convert_to_tensor=True)  # type: ignore[attr-defined]
    full_query = self.query_prefix + query + self.query_suffix
    query_embedding = self.model.encode([full_query], convert_to_tensor=True)  # type: ignore[attr-defined]

    if self.similarity == DiversityRankingSimilarity.COSINE:
        # Normalize embeddings to unit length for computing cosine similarity.
        doc_embeddings /= torch.norm(doc_embeddings, p=2, dim=-1).unsqueeze(-1)
        query_embedding /= torch.norm(query_embedding, p=2, dim=-1).unsqueeze(-1)
    return doc_embeddings, query_embedding
|
||||
|
||||
def _maximum_margin_relevance(
    self, query: str, documents: List[Document], lambda_threshold: float, top_k: int
) -> List[Document]:
    """
    Orders the given list of documents according to the Maximum Margin Relevance (MMR) scores.

    MMR scores are calculated for each document based on their relevance to the query and diversity from already
    selected documents.

    The algorithm iteratively selects documents based on their MMR scores, balancing between relevance to the query
    and diversity from already selected documents. The 'lambda_threshold' controls the trade-off between relevance
    and diversity.

    A closer value to 0 favors diversity, while a closer value to 1 favors relevance to the query.

    See: "The Use of MMR, Diversity-Based Reranking for Reordering Documents and Producing Summaries"
    https://www.cs.cmu.edu/~jgc/publication/The_Use_MMR_Diversity_Based_LTMIR_1998.pdf
    """
    texts_to_embed = self._prepare_texts_to_embed(documents)
    doc_embeddings, query_embedding = self._embed_and_normalize(query, texts_to_embed)
    if not top_k:
        top_k = len(documents)

    # Relevance of every document to the query, flattened to a 1-D tensor.
    query_similarities = (query_embedding @ doc_embeddings.T).reshape(-1)

    # Seed the selection with the single most relevant document.
    selected: List[int] = [int(torch.argmax(query_similarities))]

    while len(selected) < top_k:
        best_idx = None
        best_score = -float("inf")
        for candidate in range(len(documents)):
            if candidate in selected:
                continue
            relevance_score = query_similarities[candidate]
            # The strongest similarity to any already-selected document penalizes redundancy.
            diversity_score = max(doc_embeddings[candidate] @ doc_embeddings[chosen].T for chosen in selected)
            mmr_score = lambda_threshold * relevance_score - (1 - lambda_threshold) * diversity_score
            if mmr_score > best_score:
                best_score = mmr_score
                best_idx = candidate
        if best_idx is None:
            raise ValueError("No best document found, check if the documents list contains any documents.")
        selected.append(best_idx)

    return [documents[i] for i in selected]
|
||||
|
||||
@staticmethod
def _check_lambda_threshold(lambda_threshold: float, strategy: DiversityRankingStrategy):
    """
    Validate that lambda_threshold lies in [0, 1]; only enforced for the MMR strategy.
    """
    if strategy == DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE:
        if not 0 <= lambda_threshold <= 1:
            raise ValueError(f"lambda_threshold must be between 0 and 1, but got {lambda_threshold}.")
|
||||
|
||||
@component.output_types(documents=List[Document])
|
||||
def run(self, query: str, documents: List[Document], top_k: Optional[int] = None):
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
documents: List[Document],
|
||||
top_k: Optional[int] = None,
|
||||
lambda_threshold: Optional[float] = None,
|
||||
) -> Dict[str, List[Document]]:
|
||||
"""
|
||||
Rank the documents based on their diversity.
|
||||
|
||||
:param query: The search query.
|
||||
:param documents: List of Document objects to be ranked.
|
||||
:param top_k: Optional. An integer to override the top_k set during initialization.
|
||||
:param lambda_threshold: Override the trade-off parameter between relevance and diversity. Only used when
|
||||
strategy is "maximum_margin_relevance".
|
||||
|
||||
:returns: A dictionary with the following key:
|
||||
- `documents`: List of Document objects that have been selected based on the diversity ranking.
|
||||
@ -245,9 +385,17 @@ class SentenceTransformersDiversityRanker:
|
||||
|
||||
if top_k is None:
|
||||
top_k = self.top_k
|
||||
elif top_k <= 0:
|
||||
raise ValueError(f"top_k must be > 0, but got {top_k}")
|
||||
elif not 0 < top_k <= len(documents):
|
||||
raise ValueError(f"top_k must be between 1 and {len(documents)}, but got {top_k}")
|
||||
|
||||
diversity_sorted = self._greedy_diversity_order(query=query, documents=documents)
|
||||
if self.strategy == DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE:
|
||||
if lambda_threshold is None:
|
||||
lambda_threshold = self.lambda_threshold
|
||||
self._check_lambda_threshold(lambda_threshold, self.strategy)
|
||||
re_ranked_docs = self._maximum_margin_relevance(
|
||||
query=query, documents=documents, lambda_threshold=lambda_threshold, top_k=top_k
|
||||
)
|
||||
else:
|
||||
re_ranked_docs = self._greedy_diversity_order(query=query, documents=documents)
|
||||
|
||||
return {"documents": diversity_sorted[:top_k]}
|
||||
return {"documents": re_ranked_docs[:top_k]}
|
||||
|
||||
@ -43,7 +43,7 @@ class TransformersSimilarityRanker:
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__( # noqa: PLR0913
|
||||
def __init__( # noqa: PLR0913, pylint: disable=too-many-positional-arguments
|
||||
self,
|
||||
model: Union[str, Path] = "cross-encoder/ms-marco-MiniLM-L-6-v2",
|
||||
device: Optional[ComponentDevice] = None,
|
||||
@ -201,7 +201,7 @@ class TransformersSimilarityRanker:
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
@component.output_types(documents=List[Document])
|
||||
def run(
|
||||
def run( # pylint: disable=too-many-positional-arguments
|
||||
self,
|
||||
query: str,
|
||||
documents: List[Document],
|
||||
|
||||
@ -0,0 +1,4 @@
|
||||
---
|
||||
enhancements:
|
||||
- |
|
||||
Added the Maximum Margin Relevance (MMR) strategy to the `SentenceTransformersDiversityRanker`. MMR scores are calculated for each document based on their relevance to the query and diversity from already selected documents.
|
||||
@ -6,8 +6,12 @@ from unittest.mock import MagicMock, call, patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from haystack import Document
|
||||
from haystack import Document, Pipeline
|
||||
from haystack.components.rankers import SentenceTransformersDiversityRanker
|
||||
from haystack.components.rankers.sentence_transformers_diversity import (
|
||||
DiversityRankingSimilarity,
|
||||
DiversityRankingStrategy,
|
||||
)
|
||||
from haystack.utils import ComponentDevice
|
||||
from haystack.utils.auth import Secret
|
||||
|
||||
@ -27,7 +31,7 @@ class TestSentenceTransformersDiversityRanker:
|
||||
assert component.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
assert component.top_k == 10
|
||||
assert component.device == ComponentDevice.resolve_device(None)
|
||||
assert component.similarity == "cosine"
|
||||
assert component.similarity == DiversityRankingSimilarity.COSINE
|
||||
assert component.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
|
||||
assert component.query_prefix == ""
|
||||
assert component.document_prefix == ""
|
||||
@ -36,7 +40,7 @@ class TestSentenceTransformersDiversityRanker:
|
||||
assert component.meta_fields_to_embed == []
|
||||
assert component.embedding_separator == "\n"
|
||||
|
||||
def test_init_with_custom_init_parameters(self):
|
||||
def test_init_with_custom_parameters(self):
|
||||
component = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/msmarco-distilbert-base-v4",
|
||||
top_k=5,
|
||||
@ -53,7 +57,7 @@ class TestSentenceTransformersDiversityRanker:
|
||||
assert component.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4"
|
||||
assert component.top_k == 5
|
||||
assert component.device == ComponentDevice.from_str("cuda:0")
|
||||
assert component.similarity == "dot_product"
|
||||
assert component.similarity == DiversityRankingSimilarity.DOT_PRODUCT
|
||||
assert component.token == Secret.from_token("fake-api-token")
|
||||
assert component.query_prefix == "query:"
|
||||
assert component.document_prefix == "document:"
|
||||
@ -65,22 +69,26 @@ class TestSentenceTransformersDiversityRanker:
|
||||
def test_to_dict(self):
|
||||
component = SentenceTransformersDiversityRanker()
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker",
|
||||
"init_parameters": {
|
||||
"model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
"top_k": 10,
|
||||
"device": ComponentDevice.resolve_device(None).to_dict(),
|
||||
"similarity": "cosine",
|
||||
"token": {"env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False, "type": "env_var"},
|
||||
"query_prefix": "",
|
||||
"document_prefix": "",
|
||||
"query_suffix": "",
|
||||
"document_suffix": "",
|
||||
"meta_fields_to_embed": [],
|
||||
"embedding_separator": "\n",
|
||||
},
|
||||
assert (
|
||||
data["type"]
|
||||
== "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker"
|
||||
)
|
||||
assert data["init_parameters"]["model"] == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
assert data["init_parameters"]["top_k"] == 10
|
||||
assert data["init_parameters"]["device"] == ComponentDevice.resolve_device(None).to_dict()
|
||||
assert data["init_parameters"]["similarity"] == "cosine"
|
||||
assert data["init_parameters"]["token"] == {
|
||||
"env_vars": ["HF_API_TOKEN", "HF_TOKEN"],
|
||||
"strict": False,
|
||||
"type": "env_var",
|
||||
}
|
||||
assert data["init_parameters"]["query_prefix"] == ""
|
||||
assert data["init_parameters"]["document_prefix"] == ""
|
||||
assert data["init_parameters"]["query_suffix"] == ""
|
||||
assert data["init_parameters"]["document_suffix"] == ""
|
||||
assert data["init_parameters"]["meta_fields_to_embed"] == []
|
||||
assert data["init_parameters"]["embedding_separator"] == "\n"
|
||||
assert data["init_parameters"]["strategy"] == "greedy_diversity_order"
|
||||
|
||||
def test_from_dict(self):
|
||||
data = {
|
||||
@ -104,7 +112,7 @@ class TestSentenceTransformersDiversityRanker:
|
||||
assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
assert ranker.top_k == 10
|
||||
assert ranker.device == ComponentDevice.resolve_device(None)
|
||||
assert ranker.similarity == "cosine"
|
||||
assert ranker.similarity == DiversityRankingSimilarity.COSINE
|
||||
assert ranker.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
|
||||
assert ranker.query_prefix == ""
|
||||
assert ranker.document_prefix == ""
|
||||
@ -135,7 +143,7 @@ class TestSentenceTransformersDiversityRanker:
|
||||
assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
assert ranker.top_k == 10
|
||||
assert ranker.device == ComponentDevice.resolve_device(None)
|
||||
assert ranker.similarity == "cosine"
|
||||
assert ranker.similarity == DiversityRankingSimilarity.COSINE
|
||||
assert ranker.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
|
||||
assert ranker.query_prefix == ""
|
||||
assert ranker.document_prefix == ""
|
||||
@ -154,7 +162,7 @@ class TestSentenceTransformersDiversityRanker:
|
||||
assert ranker.model_name_or_path == "sentence-transformers/all-MiniLM-L6-v2"
|
||||
assert ranker.top_k == 10
|
||||
assert ranker.device == ComponentDevice.resolve_device(None)
|
||||
assert ranker.similarity == "cosine"
|
||||
assert ranker.similarity == DiversityRankingSimilarity.COSINE
|
||||
assert ranker.token == Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False)
|
||||
assert ranker.query_prefix == ""
|
||||
assert ranker.document_prefix == ""
|
||||
@ -163,7 +171,7 @@ class TestSentenceTransformersDiversityRanker:
|
||||
assert ranker.meta_fields_to_embed == []
|
||||
assert ranker.embedding_separator == "\n"
|
||||
|
||||
def test_to_dict_with_custom_init_parameters(self):
|
||||
def test_to_dict_with_custom_parameters(self):
|
||||
component = SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/msmarco-distilbert-base-v4",
|
||||
top_k=5,
|
||||
@ -178,22 +186,23 @@ class TestSentenceTransformersDiversityRanker:
|
||||
embedding_separator="--",
|
||||
)
|
||||
data = component.to_dict()
|
||||
assert data == {
|
||||
"type": "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker",
|
||||
"init_parameters": {
|
||||
"model": "sentence-transformers/msmarco-distilbert-base-v4",
|
||||
"top_k": 5,
|
||||
"device": ComponentDevice.from_str("cuda:0").to_dict(),
|
||||
"token": {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"},
|
||||
"similarity": "dot_product",
|
||||
"query_prefix": "query:",
|
||||
"document_prefix": "document:",
|
||||
"query_suffix": "query suffix",
|
||||
"document_suffix": "document suffix",
|
||||
"meta_fields_to_embed": ["meta_field"],
|
||||
"embedding_separator": "--",
|
||||
},
|
||||
}
|
||||
|
||||
assert (
|
||||
data["type"]
|
||||
== "haystack.components.rankers.sentence_transformers_diversity.SentenceTransformersDiversityRanker"
|
||||
)
|
||||
assert data["init_parameters"]["model"] == "sentence-transformers/msmarco-distilbert-base-v4"
|
||||
assert data["init_parameters"]["top_k"] == 5
|
||||
assert data["init_parameters"]["device"] == ComponentDevice.from_str("cuda:0").to_dict()
|
||||
assert data["init_parameters"]["token"] == {"env_vars": ["ENV_VAR"], "strict": False, "type": "env_var"}
|
||||
assert data["init_parameters"]["similarity"] == "dot_product"
|
||||
assert data["init_parameters"]["query_prefix"] == "query:"
|
||||
assert data["init_parameters"]["document_prefix"] == "document:"
|
||||
assert data["init_parameters"]["query_suffix"] == "query suffix"
|
||||
assert data["init_parameters"]["document_suffix"] == "document suffix"
|
||||
assert data["init_parameters"]["meta_fields_to_embed"] == ["meta_field"]
|
||||
assert data["init_parameters"]["embedding_separator"] == "--"
|
||||
assert data["init_parameters"]["strategy"] == "greedy_diversity_order"
|
||||
|
||||
def test_from_dict_with_custom_init_parameters(self):
|
||||
data = {
|
||||
@ -217,7 +226,7 @@ class TestSentenceTransformersDiversityRanker:
|
||||
assert ranker.model_name_or_path == "sentence-transformers/msmarco-distilbert-base-v4"
|
||||
assert ranker.top_k == 5
|
||||
assert ranker.device == ComponentDevice.from_str("cuda:0")
|
||||
assert ranker.similarity == "dot_product"
|
||||
assert ranker.similarity == DiversityRankingSimilarity.DOT_PRODUCT
|
||||
assert ranker.token == Secret.from_env_var("ENV_VAR", strict=False)
|
||||
assert ranker.query_prefix == "query:"
|
||||
assert ranker.document_prefix == "document:"
|
||||
@ -226,16 +235,24 @@ class TestSentenceTransformersDiversityRanker:
|
||||
assert ranker.meta_fields_to_embed == ["meta_field"]
|
||||
assert ranker.embedding_separator == "--"
|
||||
|
||||
def test_run_incorrect_similarity(self):
|
||||
def test_run_invalid_similarity(self):
|
||||
"""
|
||||
Tests that run method raises ValueError if similarity is incorrect
|
||||
"""
|
||||
similarity = "incorrect"
|
||||
with pytest.raises(
|
||||
ValueError, match=f"Similarity must be one of 'dot_product' or 'cosine', but got {similarity}."
|
||||
):
|
||||
with pytest.raises(ValueError, match=f"Unknown similarity metric"):
|
||||
SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity)
|
||||
|
||||
def test_run_invalid_strategy(self):
|
||||
"""
|
||||
Tests that run method raises ValueError if strategy is incorrect
|
||||
"""
|
||||
strategy = "incorrect"
|
||||
with pytest.raises(ValueError, match=f"Unknown strategy"):
|
||||
SentenceTransformersDiversityRanker(
|
||||
model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine", strategy=strategy
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_run_without_warm_up(self, similarity):
|
||||
"""
|
||||
@ -362,7 +379,7 @@ class TestSentenceTransformersDiversityRanker:
|
||||
query = "test"
|
||||
documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")]
|
||||
|
||||
with pytest.raises(ValueError, match="top_k must be > 0, but got"):
|
||||
with pytest.raises(ValueError, match="top_k must be between"):
|
||||
ranker.run(query=query, documents=documents, top_k=-5)
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
@ -509,6 +526,52 @@ class TestSentenceTransformersDiversityRanker:
|
||||
|
||||
assert ranked_text == "Berlin Eiffel Tower Bananas"
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_maximum_margin_relevance(self, similarity):
    """
    MMR with lambda_threshold=0 (pure diversity) still returns the most relevant document first.
    """
    ranker = SentenceTransformersDiversityRanker(
        model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity
    )
    # Replace the real model with a deterministic stub so no download or inference happens.
    # (The original set up the stub twice; once is sufficient.)
    ranker.model = MagicMock()
    ranker.model.encode = MagicMock(side_effect=mock_encode_response)

    query = "city"
    documents = [Document(content="Eiffel Tower"), Document(content="Berlin"), Document(content="Bananas")]

    ranked_docs = ranker._maximum_margin_relevance(query=query, documents=documents, lambda_threshold=0, top_k=3)
    ranked_text = " ".join([doc.content for doc in ranked_docs])

    assert ranked_text == "Berlin Eiffel Tower Bananas"
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_maximum_margin_relevance_with_given_lambda_threshold(self, similarity):
    """
    MMR with lambda_threshold=1 (pure relevance) orders documents by query relevance.
    """
    ranker = SentenceTransformersDiversityRanker(
        model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity
    )
    # Replace the real model with a deterministic stub so no download or inference happens.
    # (The original set up the stub twice; once is sufficient.)
    ranker.model = MagicMock()
    ranker.model.encode = MagicMock(side_effect=mock_encode_response)

    query = "city"
    documents = [Document(content="Eiffel Tower"), Document(content="Berlin"), Document(content="Bananas")]

    ranked_docs = ranker._maximum_margin_relevance(query=query, documents=documents, lambda_threshold=1, top_k=3)
    ranked_text = " ".join([doc.content for doc in ranked_docs])

    assert ranked_text == "Berlin Eiffel Tower Bananas"
|
||||
|
||||
def test_pipeline_serialise_deserialise(self):
    """
    A pipeline containing the ranker round-trips through serialization unchanged.
    """
    ranker = SentenceTransformersDiversityRanker(
        model="sentence-transformers/all-MiniLM-L6-v2", similarity="cosine", top_k=5
    )

    pipeline = Pipeline()
    pipeline.add_component("ranker", ranker)
    serialized = pipeline.dumps()
    assert Pipeline.loads(serialized) == pipeline
|
||||
|
||||
@pytest.mark.integration
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
def test_run(self, similarity):
|
||||
@ -607,3 +670,51 @@ class TestSentenceTransformersDiversityRanker:
|
||||
|
||||
# Check the order of ranked documents by comparing the content of the ranked documents
|
||||
assert result_content == expected_content
|
||||
|
||||
@pytest.mark.integration
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
def test_run_with_maximum_margin_relevance_strategy(self, similarity):
    """
    End-to-end check of the MMR strategy at both extremes of lambda_threshold.
    """
    query = "renewable energy sources"
    docs = [
        Document(content="18th-century French literature"),
        Document(content="Solar power generation"),
        Document(content="Ancient Egyptian hieroglyphics"),
        Document(content="Wind turbine technology"),
        Document(content="Baking sourdough bread"),
        Document(content="Hydroelectric dam systems"),
        Document(content="Geothermal energy extraction"),
        Document(content="Biomass fuel production"),
    ]

    ranker = SentenceTransformersDiversityRanker(
        model="sentence-transformers/all-MiniLM-L6-v2", similarity=similarity, strategy="maximum_margin_relevance"
    )
    ranker.warm_up()

    # lambda_threshold=1 ranks purely by relevance to the query;
    # lambda_threshold=0 favors diversity after the single most relevant document.
    expected_by_lambda = {
        1: [
            "Solar power generation",
            "Wind turbine technology",
            "Geothermal energy extraction",
            "Hydroelectric dam systems",
            "Biomass fuel production",
            "Ancient Egyptian hieroglyphics",
            "Baking sourdough bread",
            "18th-century French literature",
        ],
        0: [
            "Solar power generation",
            "Ancient Egyptian hieroglyphics",
            "Baking sourdough bread",
            "18th-century French literature",
            "Biomass fuel production",
            "Hydroelectric dam systems",
            "Geothermal energy extraction",
            "Wind turbine technology",
        ],
    }
    for lambda_threshold, expected in expected_by_lambda.items():
        results = ranker.run(query=query, documents=docs, lambda_threshold=lambda_threshold, top_k=len(docs))
        assert [doc.content for doc in results["documents"]] == expected
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user