diff --git a/haystack/components/rankers/sentence_transformers_diversity.py b/haystack/components/rankers/sentence_transformers_diversity.py index fa1cfe3f0..648a3baa0 100644 --- a/haystack/components/rankers/sentence_transformers_diversity.py +++ b/haystack/components/rankers/sentence_transformers_diversity.py @@ -353,7 +353,7 @@ class SentenceTransformersDiversityRanker: texts_to_embed = self._prepare_texts_to_embed(documents) doc_embeddings, query_embedding = self._embed_and_normalize(query, texts_to_embed) - top_k = top_k if top_k else len(documents) + top_k = min(top_k, len(documents)) selected: list[int] = [] query_similarities_as_tensor = query_embedding @ doc_embeddings.T @@ -375,9 +375,8 @@ class SentenceTransformersDiversityRanker: if mmr_score > best_score: best_score = mmr_score best_idx = idx - if best_idx is None: - raise ValueError("No best document found, check if the documents list contains any documents.") - selected.append(best_idx) + # loop condition ensures unselected docs exist with valid scores + selected.append(best_idx) # type: ignore[arg-type] return [documents[i] for i in selected] @@ -421,8 +420,8 @@ class SentenceTransformersDiversityRanker: if top_k is None: top_k = self.top_k - elif not 0 < top_k <= len(documents): - raise ValueError(f"top_k must be between 1 and {len(documents)}, but got {top_k}") + if top_k <= 0: + raise ValueError(f"top_k must be > 0, but got {top_k}") if self.strategy == DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE: if lambda_threshold is None: diff --git a/releasenotes/notes/st-diversity-ranker-top-k-consistency-29535a76a07a8ff3.yaml b/releasenotes/notes/st-diversity-ranker-top-k-consistency-29535a76a07a8ff3.yaml new file mode 100644 index 000000000..1ea67a100 --- /dev/null +++ b/releasenotes/notes/st-diversity-ranker-top-k-consistency-29535a76a07a8ff3.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + Ensure consistent behavior in `SentenceTransformersDiversityRanker`. Like other rankers, it now returns + all documents instead of raising an error when `top_k` exceeds the number of available documents. diff --git a/test/components/rankers/test_sentence_transformers_diversity.py b/test/components/rankers/test_sentence_transformers_diversity.py index 5a9cd2bdc..8a050d3fe 100644 --- a/test/components/rankers/test_sentence_transformers_diversity.py +++ b/test/components/rankers/test_sentence_transformers_diversity.py @@ -381,7 +381,7 @@ class TestSentenceTransformersDiversityRanker: query = "test" documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] - with pytest.raises(ValueError, match="top_k must be between"): + with pytest.raises(ValueError, match="top_k must be > 0"): ranker.run(query=query, documents=documents, top_k=-5) @pytest.mark.parametrize("similarity", ["dot_product", "cosine"])