mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-21 03:58:43 +00:00
fix: fix inconsistent top_k validation in SentenceTransformersDiversityRanker (#9698)
* Fix inconsistent top_k validation in SentenceTransformersDiversityRanker - change elif to if in run() method to ensure top_k validation always runs regardless of whether top_k comes from init or runtime - Both scenarios now consistently raise ValueError with descriptive message format: 'top_k must be between 1 and X, but got Y' - Fixes inconsistency where init top_k gave confusing MMR error while runtime top_k gave clear validation error * improvements --------- Co-authored-by: Stefano Fiorucci <stefanofiorucci@gmail.com>
This commit is contained in:
parent
9ce48f9509
commit
ae6f3bcf7c
@ -353,7 +353,7 @@ class SentenceTransformersDiversityRanker:
|
||||
|
||||
texts_to_embed = self._prepare_texts_to_embed(documents)
|
||||
doc_embeddings, query_embedding = self._embed_and_normalize(query, texts_to_embed)
|
||||
top_k = top_k if top_k else len(documents)
|
||||
top_k = min(top_k, len(documents))
|
||||
|
||||
selected: list[int] = []
|
||||
query_similarities_as_tensor = query_embedding @ doc_embeddings.T
|
||||
@ -375,9 +375,8 @@ class SentenceTransformersDiversityRanker:
|
||||
if mmr_score > best_score:
|
||||
best_score = mmr_score
|
||||
best_idx = idx
|
||||
if best_idx is None:
|
||||
raise ValueError("No best document found, check if the documents list contains any documents.")
|
||||
selected.append(best_idx)
|
||||
# loop condition ensures unselected docs exist with valid scores
|
||||
selected.append(best_idx) # type: ignore[arg-type]
|
||||
|
||||
return [documents[i] for i in selected]
|
||||
|
||||
@ -421,8 +420,8 @@ class SentenceTransformersDiversityRanker:
|
||||
|
||||
if top_k is None:
|
||||
top_k = self.top_k
|
||||
elif not 0 < top_k <= len(documents):
|
||||
raise ValueError(f"top_k must be between 1 and {len(documents)}, but got {top_k}")
|
||||
if top_k <= 0:
|
||||
raise ValueError(f"top_k must be > 0, but got {top_k}")
|
||||
|
||||
if self.strategy == DiversityRankingStrategy.MAXIMUM_MARGIN_RELEVANCE:
|
||||
if lambda_threshold is None:
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
---
|
||||
fixes:
|
||||
- |
|
||||
Ensure consistent behavior in `SentenceTransformersDiversityRanker`. Like other rankers, it now returns
|
||||
all documents instead of raising an error when `top_k` exceeds the number of available documents.
|
||||
@ -381,7 +381,7 @@ class TestSentenceTransformersDiversityRanker:
|
||||
query = "test"
|
||||
documents = [Document(content="doc1"), Document(content="doc2"), Document(content="doc3")]
|
||||
|
||||
with pytest.raises(ValueError, match="top_k must be between"):
|
||||
with pytest.raises(ValueError, match="top_k must be > 0"):
|
||||
ranker.run(query=query, documents=documents, top_k=-5)
|
||||
|
||||
@pytest.mark.parametrize("similarity", ["dot_product", "cosine"])
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user