From 594d2a10f84d13aef495c1cfbdaf4acad730c914 Mon Sep 17 00:00:00 2001 From: bogdankostic Date: Thu, 29 Dec 2022 11:48:18 +0100 Subject: [PATCH] fix: Fix `predict_batch` in `TransformersReader` for single nested Document list (#3748) * Fix restoring of list structure * Add tests --- haystack/nodes/reader/transformers.py | 4 ++-- test/nodes/test_reader.py | 12 ++++++++++++ test/pipelines/test_eval_batch.py | 2 +- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/haystack/nodes/reader/transformers.py b/haystack/nodes/reader/transformers.py index ab0e4509c..3de52c936 100644 --- a/haystack/nodes/reader/transformers.py +++ b/haystack/nodes/reader/transformers.py @@ -153,7 +153,7 @@ class TransformersReader(BaseReader): max_seq_len=self.max_seq_len, doc_stride=self.doc_stride, ) - # Transformers gives different output dependiing on top_k_per_candidate and number of inputs + # Transformers gives different output depending on top_k_per_candidate and number of inputs if isinstance(predictions, dict): predictions = [[predictions]] elif len(inputs) == 1: @@ -224,7 +224,7 @@ class TransformersReader(BaseReader): # Transformers flattens lists of length 1. This restores the original list structure. 
if isinstance(predictions, dict): predictions = [[predictions]] - elif len(number_of_docs) == 1: + elif len(inputs) == 1: predictions = [predictions] else: predictions = [p if isinstance(p, list) else [p] for p in predictions] diff --git a/test/nodes/test_reader.py b/test/nodes/test_reader.py index dfd7182ff..28dfb3d74 100644 --- a/test/nodes/test_reader.py +++ b/test/nodes/test_reader.py @@ -107,6 +107,18 @@ def test_output_batch_multiple_queries_multiple_doc_lists(reader, docs): assert len(prediction["answers"][0]) == 5 # top-k of 5 for collection of docs +def test_output_batch_single_query_single_nested_doc_list(reader, docs): + prediction = reader.predict_batch(queries=["Who lives in Berlin?"], documents=[docs], top_k=5) + assert prediction is not None + assert prediction["queries"] == ["Who lives in Berlin?"] + # Expected output: List of lists of answers + assert isinstance(prediction["answers"], list) + assert isinstance(prediction["answers"][0], list) + assert isinstance(prediction["answers"][0][0], Answer) + assert len(prediction["answers"]) == 1 # Predictions for 1 collection of documents + assert len(prediction["answers"][0]) == 5 # top-k of 5 for collection of docs + + @pytest.mark.integration def test_no_answer_output(no_answer_reader, docs): no_answer_prediction = no_answer_reader.predict(query="What is the meaning of life?", documents=docs, top_k=5) diff --git a/test/pipelines/test_eval_batch.py b/test/pipelines/test_eval_batch.py index 3b48402d2..6b66aa543 100644 --- a/test/pipelines/test_eval_batch.py +++ b/test/pipelines/test_eval_batch.py @@ -118,7 +118,7 @@ EVAL_LABELS = [ @pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True) @pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True) -@pytest.mark.parametrize("reader", ["farm"], indirect=True) +@pytest.mark.parametrize("reader", ["farm", "transformers"], indirect=True) def test_extractive_qa_eval(reader, retriever_with_docs, tmp_path): labels = 
EVAL_LABELS[:1]