fix: SentenceTransformersRanker's predict_batch returns wrong number of documents (#4756)

* Fix SentenceTransformersRanker's predict_batch returning wrong number of documents

* Julian's feedback
This commit is contained in:
Vladimir Blagojevic 2023-04-27 15:24:39 +02:00 committed by GitHub
parent c9a415ec8d
commit dcaf3002f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 2 deletions

View File

@ -225,7 +225,7 @@ class SentenceTransformersRanker(BaseRanker):
logits_dim = similarity_scores.shape[1] # [batch_size, logits_dim]
if single_list_of_docs:
sorted_scores_and_documents = sorted(
zip(similarity_scores, documents),
zip(preds, documents),
key=lambda similarity_document_tuple:
# assume the last element in logits represents the `has_answer` label
similarity_document_tuple[0][-1] if logits_dim >= 2 else similarity_document_tuple[0],
@ -244,7 +244,7 @@ class SentenceTransformersRanker(BaseRanker):
right_idx = 0
for number in number_of_docs:
right_idx = left_idx + number
grouped_predictions.append(similarity_scores[left_idx:right_idx])
grouped_predictions.append(preds[left_idx:right_idx])
left_idx = right_idx
result = []

View File

@ -225,3 +225,11 @@ def test_ranker_returns_raw_score_for_two_logits(ranker_two_logits):
score = results[0].score
precomputed_score = -3.61354
assert math.isclose(precomputed_score, score, rel_tol=0.001)
def test_predict_batch_returns_correct_number_of_docs(ranker):
    """Check that `predict` and `predict_batch` both return exactly `top_k` documents.

    Five documents are ranked against one query with `top_k=4`; the batch
    call uses `batch_size=2`, which is smaller than the number of documents.
    """
    query = "where is test 3?"
    top_k = 4
    candidates = [Document(content=f"test {number}") for number in range(5)]

    single_result = ranker.predict(query, candidates, top_k=top_k)
    assert len(single_result) == 4

    batch_result = ranker.predict_batch([query], candidates, batch_size=2, top_k=top_k)
    assert len(batch_result) == 4