mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-24 09:20:13 +00:00
fix: Fix predict_batch
in TransformersReader
for single nested Document list (#3748)
* Fix restoring of list structure * Add tests
This commit is contained in:
parent
136928714c
commit
594d2a10f8
@ -153,7 +153,7 @@ class TransformersReader(BaseReader):
|
||||
max_seq_len=self.max_seq_len,
|
||||
doc_stride=self.doc_stride,
|
||||
)
|
||||
# Transformers gives different output dependiing on top_k_per_candidate and number of inputs
|
||||
# Transformers gives different output depending on top_k_per_candidate and number of inputs
|
||||
if isinstance(predictions, dict):
|
||||
predictions = [[predictions]]
|
||||
elif len(inputs) == 1:
|
||||
@ -224,7 +224,7 @@ class TransformersReader(BaseReader):
|
||||
# Transformers flattens lists of length 1. This restores the original list structure.
|
||||
if isinstance(predictions, dict):
|
||||
predictions = [[predictions]]
|
||||
elif len(number_of_docs) == 1:
|
||||
elif len(inputs) == 1:
|
||||
predictions = [predictions]
|
||||
else:
|
||||
predictions = [p if isinstance(p, list) else [p] for p in predictions]
|
||||
|
@ -107,6 +107,18 @@ def test_output_batch_multiple_queries_multiple_doc_lists(reader, docs):
|
||||
assert len(prediction["answers"][0]) == 5 # top-k of 5 for collection of docs
|
||||
|
||||
|
||||
def test_output_batch_single_query_single_nested_doc_list(reader, docs):
|
||||
prediction = reader.predict_batch(queries=["Who lives in Berlin?"], documents=[docs], top_k=5)
|
||||
assert prediction is not None
|
||||
assert prediction["queries"] == ["Who lives in Berlin?"]
|
||||
# Expected output: List of lists answers
|
||||
assert isinstance(prediction["answers"], list)
|
||||
assert isinstance(prediction["answers"][0], list)
|
||||
assert isinstance(prediction["answers"][0][0], Answer)
|
||||
assert len(prediction["answers"]) == 1 # Predictions for 1 collections of documents
|
||||
assert len(prediction["answers"][0]) == 5 # top-k of 5 for collection of docs
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_no_answer_output(no_answer_reader, docs):
|
||||
no_answer_prediction = no_answer_reader.predict(query="What is the meaning of life?", documents=docs, top_k=5)
|
||||
|
@ -118,7 +118,7 @@ EVAL_LABELS = [
|
||||
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
|
||||
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
|
||||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
|
||||
@pytest.mark.parametrize("reader", ["farm", "transformers"], indirect=True)
|
||||
def test_extractive_qa_eval(reader, retriever_with_docs, tmp_path):
|
||||
labels = EVAL_LABELS[:1]
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user