fix: Fix predict_batch in TransformersReader for single nested Document list (#3748)

* Fix restoring of list structure

* Add tests
This commit is contained in:
bogdankostic 2022-12-29 11:48:18 +01:00 committed by GitHub
parent 136928714c
commit 594d2a10f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 3 deletions

View File

@ -153,7 +153,7 @@ class TransformersReader(BaseReader):
max_seq_len=self.max_seq_len,
doc_stride=self.doc_stride,
)
# Transformers gives different output dependiing on top_k_per_candidate and number of inputs
# Transformers gives different output depending on top_k_per_candidate and number of inputs
if isinstance(predictions, dict):
predictions = [[predictions]]
elif len(inputs) == 1:
@ -224,7 +224,7 @@ class TransformersReader(BaseReader):
# Transformers flattens lists of length 1. This restores the original list structure.
if isinstance(predictions, dict):
predictions = [[predictions]]
elif len(number_of_docs) == 1:
elif len(inputs) == 1:
predictions = [predictions]
else:
predictions = [p if isinstance(p, list) else [p] for p in predictions]

View File

@ -107,6 +107,18 @@ def test_output_batch_multiple_queries_multiple_doc_lists(reader, docs):
assert len(prediction["answers"][0]) == 5 # top-k of 5 for collection of docs
def test_output_batch_single_query_single_nested_doc_list(reader, docs):
prediction = reader.predict_batch(queries=["Who lives in Berlin?"], documents=[docs], top_k=5)
assert prediction is not None
assert prediction["queries"] == ["Who lives in Berlin?"]
# Expected output: List of lists answers
assert isinstance(prediction["answers"], list)
assert isinstance(prediction["answers"][0], list)
assert isinstance(prediction["answers"][0][0], Answer)
assert len(prediction["answers"]) == 1 # Predictions for 1 collections of documents
assert len(prediction["answers"][0]) == 5 # top-k of 5 for collection of docs
@pytest.mark.integration
def test_no_answer_output(no_answer_reader, docs):
no_answer_prediction = no_answer_reader.predict(query="What is the meaning of life?", documents=docs, top_k=5)

View File

@ -118,7 +118,7 @@ EVAL_LABELS = [
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
@pytest.mark.parametrize("document_store_with_docs", ["memory"], indirect=True)
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
@pytest.mark.parametrize("reader", ["farm", "transformers"], indirect=True)
def test_extractive_qa_eval(reader, retriever_with_docs, tmp_path):
labels = EVAL_LABELS[:1]