fix: fix inconsistent top_k validation in SentenceTransformersDiversityRanker (#9698)

* Fix inconsistent top_k validation in SentenceTransformersDiversityRanker
- Change elif to if in the run() method so that top_k validation always
  runs, regardless of whether top_k comes from init or runtime (see the sketch below)
- Both scenarios now consistently raise ValueError with descriptive
  message format: 'top_k must be between 1 and X, but got Y'
- Fixes inconsistency where init top_k gave confusing MMR error while
  runtime top_k gave clear validation error
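
To make the elif-to-if change concrete, here is a simplified sketch of the run()-time check (illustrative only, not the component's exact code; the final diff below uses the message "top_k must be > 0"):

    def run(self, query, documents, top_k=None):
        if top_k is None:
            top_k = self.top_k  # fall back to the top_k passed at init
        # With "elif", an invalid init-time top_k skipped this check and only failed
        # later inside the MMR loop; a plain "if" validates both sources the same way.
        if top_k <= 0:
            raise ValueError(f"top_k must be > 0, but got {top_k}")
        ...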

* improvements

---------

Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
Saurabh Lingam, 2025-08-14 21:04:29 +05:30, committed by GitHub
parent 9ce48f9509
commit ae6f3bcf7c
3 changed files with 11 additions and 7 deletions

View File

@@ -353,7 +353,7 @@ class SentenceTransformersDiversityRanker:
         texts_to_embed = self._prepare_texts_to_embed(documents)
         doc_embeddings, query_embedding = self._embed_and_normalize(query, texts_to_embed)
-        top_k = top_k if top_k else len(documents)
+        top_k = min(top_k, len(documents))
         selected: list[int] = []
         query_similarities_as_tensor = query_embedding @ doc_embeddings.T
@@ -375,9 +375,8 @@ class SentenceTransformersDiversityRanker:
                 if mmr_score > best_score:
                     best_score = mmr_score
                     best_idx = idx
-            if best_idx is None:
-                raise ValueError("No best document found, check if the documents list contains any documents.")
-            selected.append(best_idx)
+            # loop condition ensures unselected docs exist with valid scores
+            selected.append(best_idx)  # type: ignore[arg-type]
         return [documents[i] for i in selected]
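
For context on why the None check could be dropped: MMR greedily picks, at each step, the unselected document with the best balance of query relevance and dissimilarity to the already-selected set, so as long as unselected documents remain there is always a best candidate. A standalone sketch of that selection step under the standard MMR formula (the numpy setup and variable names are illustrative, not the component's internals):

    import numpy as np

    def mmr_select(query_sims: np.ndarray, doc_sims: np.ndarray, top_k: int, lambda_threshold: float = 0.5) -> list[int]:
        if top_k <= 0:
            raise ValueError(f"top_k must be > 0, but got {top_k}")
        top_k = min(top_k, len(query_sims))  # clamp instead of raising, mirroring the hunk above
        selected = [int(np.argmax(query_sims))]  # first pick: the most query-relevant document
        while len(selected) < top_k:
            best_idx, best_score = None, -np.inf
            for idx in range(len(query_sims)):
                if idx in selected:
                    continue
                # relevance to the query minus redundancy with already-selected documents
                redundancy = max(doc_sims[idx, j] for j in selected)
                mmr_score = lambda_threshold * query_sims[idx] - (1 - lambda_threshold) * redundancy
                if mmr_score > best_score:
                    best_score, best_idx = mmr_score, idx
            # the while-condition guarantees an unselected candidate exists, so best_idx is set here
            selected.append(best_idx)
        return selected
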
@@ -421,8 +420,8 @@
         if top_k is None:
             top_k = self.top_k
-        elif not 0 < top_k <= len(documents):
-            raise ValueError(f"top_k must be between 1 and {len(documents)}, but got {top_k}")
+        if top_k <= 0:
+            raise ValueError(f"top_k must be > 0, but got {top_k}")
         if self.strategy == DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE:
             if lambda_threshold is None:

View File

@@ -0,0 +1,5 @@
+---
+fixes:
+  - |
+    Ensure consistent behavior in `SentenceTransformersDiversityRanker`. Like other rankers, it now returns
+    all documents instead of raising an error when `top_k` exceeds the number of available documents.
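
As a usage illustration of the release note above, a hedged sketch of the new behavior (the model name is an assumption, and the default strategy is used):

    from haystack import Document
    from haystack.components.rankers import SentenceTransformersDiversityRanker

    # model choice is an assumption for this sketch
    ranker = SentenceTransformersDiversityRanker(model="sentence-transformers/all-MiniLM-L6-v2")
    ranker.warm_up()

    docs = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")]

    # top_k larger than the number of documents no longer raises; all three documents come back
    result = ranker.run(query="test", documents=docs, top_k=10)
    assert len(result["documents"]) == 3

    # a non-positive top_k still raises:
    # ranker.run(query="test", documents=docs, top_k=-5)  -> ValueError: top_k must be > 0, but got -5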

View File

@@ -381,7 +381,7 @@ class TestSentenceTransformersDiversityRanker:
         query = "test"
         documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")]
-        with pytest.raises(ValueError, match="top_k must be between"):
+        with pytest.raises(ValueError, match="top_k must be > 0"):
             ranker.run(query=query, documents=documents, top_k=-5)
     @pytest.mark.parametrize("similarity", ["dot_product", "cosine"])