From ae6f3bcf7cda379f4f060ac71acb69ad3c9b6dba Mon Sep 17 00:00:00 2001 From: Saurabh Lingam <141753821+SaurabhLingam@users.noreply.github.com> Date: Thu, 14 Aug 2025 21:04:29 +0530 Subject: [PATCH] fix: fix inconsistent `top_k` validation in `SentenceTransformersDiversityRanker` (#9698) * Fix inconsistent top_k validation in SentenceTransformersDiversityRanker - change elif to if in run() method to ensure top_k validation always runs regardless of whatever top_k comes from init or runtime - Both scenarios now consistently raise ValueError with descriptive message format: 'top_k must be between 1 and X, but got Y' - Fixes inconsistency where init top_k gave confusing MMR error while runtime top_k gave clear validation error * improvements --------- Co-authored-by: Stefano Fiorucci --- .../rankers/sentence_transformers_diversity.py | 11 +++++------ ...ity-ranker-top-k-consistency-29535a76a07a8ff3.yaml | 5 +++++ .../rankers/test_sentence_transformers_diversity.py | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) create mode 100644 releasenotes/notes/st-diversity-ranker-top-k-consistency-29535a76a07a8ff3.yaml diff --git a/haystack/components/rankers/sentence_transformers_diversity.py b/haystack/components/rankers/sentence_transformers_diversity.py index fa1cfe3f0..648a3baa0 100644 --- a/haystack/components/rankers/sentence_transformers_diversity.py +++ b/haystack/components/rankers/sentence_transformers_diversity.py @@ -353,7 +353,7 @@ class SentenceTransformersDiversityRanker: texts_to_embed = self._prepare_texts_to_embed(documents) doc_embeddings, query_embedding = self._embed_and_normalize(query, texts_to_embed) - top_k = top_k if top_k else len(documents) + top_k = min(top_k, len(documents)) selected: list[int] = [] query_similarities_as_tensor = query_embedding @ doc_embeddings.T @@ -375,9 +375,8 @@ class SentenceTransformersDiversityRanker: if mmr_score > best_score: best_score = mmr_score best_idx = idx - if best_idx is None: - raise ValueError("No best document found, check if the documents list contains any documents.") - selected.append(best_idx) + # loop condition ensures unselected docs exist with valid scores + selected.append(best_idx) # type: ignore[arg-type] return [documents[i] for i in selected] @@ -421,8 +420,8 @@ class SentenceTransformersDiversityRanker: if top_k is None: top_k = self.top_k - elif not 0 < top_k <= len(documents): - raise ValueError(f"top_k must be between 1 and {len(documents)}, but got {top_k}") + if top_k <= 0: + raise ValueError(f"top_k must be > 0, but got {top_k}") if self.strategy == DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE: if lambda_threshold is None: diff --git a/releasenotes/notes/st-diversity-ranker-top-k-consistency-29535a76a07a8ff3.yaml b/releasenotes/notes/st-diversity-ranker-top-k-consistency-29535a76a07a8ff3.yaml new file mode 100644 index 000000000..1ea67a100 --- /dev/null +++ b/releasenotes/notes/st-diversity-ranker-top-k-consistency-29535a76a07a8ff3.yaml @@ -0,0 +1,5 @@ +--- +fixes: + - | + Ensure consistent behavior in `SentenceTransformersDiversityRanker`. Like other rankers, it now returns + all documents instead of raising an error when `top_k` exceeds the number of available documents. diff --git a/test/components/rankers/test_sentence_transformers_diversity.py b/test/components/rankers/test_sentence_transformers_diversity.py index 5a9cd2bdc..8a050d3fe 100644 --- a/test/components/rankers/test_sentence_transformers_diversity.py +++ b/test/components/rankers/test_sentence_transformers_diversity.py @@ -381,7 +381,7 @@ class TestSentenceTransformersDiversityRanker: query = "test" documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")] - with pytest.raises(ValueError, match="top_k must be between"): + with pytest.raises(ValueError, match="top_k must be > 0"): ranker.run(query=query, documents=documents, top_k=-5) @pytest.mark.parametrize("similarity", ["dot_product", "cosine"])