diff --git a/haystack/nodes/ranker/sentence_transformers.py b/haystack/nodes/ranker/sentence_transformers.py index 5b289871f..7648cd2d8 100644 --- a/haystack/nodes/ranker/sentence_transformers.py +++ b/haystack/nodes/ranker/sentence_transformers.py @@ -225,7 +225,7 @@ class SentenceTransformersRanker(BaseRanker): logits_dim = similarity_scores.shape[1] # [batch_size, logits_dim] if single_list_of_docs: sorted_scores_and_documents = sorted( - zip(similarity_scores, documents), + zip(preds, documents), key=lambda similarity_document_tuple: # assume the last element in logits represents the `has_answer` label similarity_document_tuple[0][-1] if logits_dim >= 2 else similarity_document_tuple[0], @@ -244,7 +244,7 @@ class SentenceTransformersRanker(BaseRanker): right_idx = 0 for number in number_of_docs: right_idx = left_idx + number - grouped_predictions.append(similarity_scores[left_idx:right_idx]) + grouped_predictions.append(preds[left_idx:right_idx]) left_idx = right_idx result = [] diff --git a/test/nodes/test_ranker.py b/test/nodes/test_ranker.py index d7b8e9a19..b2b79b4f0 100644 --- a/test/nodes/test_ranker.py +++ b/test/nodes/test_ranker.py @@ -225,3 +225,11 @@ def test_ranker_returns_raw_score_for_two_logits(ranker_two_logits): score = results[0].score precomputed_score = -3.61354 assert math.isclose(precomputed_score, score, rel_tol=0.001) + + +def test_predict_batch_returns_correct_number_of_docs(ranker): + docs = [Document(content=f"test {number}") for number in range(5)] + + assert len(ranker.predict("where is test 3?", docs, top_k=4)) == 4 + + assert len(ranker.predict_batch(["where is test 3?"], docs, batch_size=2, top_k=4)) == 4