"""End-to-end test for an extractive question-answering pipeline.

Recovered from the patch "test: e2e test for Extractive QA Pipeline (#5879)",
which adds this module as e2e/preview/pipelines/test_extractive_qa_pipeline.py.
"""
from haystack.preview import Pipeline, Document
from haystack.preview.document_stores import MemoryDocumentStore
from haystack.preview.components.retrievers import MemoryBM25Retriever
from haystack.preview.components.readers import ExtractiveReader


def test_extractive_qa_pipeline():
    """Ask three questions through a retriever -> reader pipeline and check the answers.

    Three short documents are written to an in-memory document store. For each
    document there is one matching question and one "spyword" that the best
    answer is required to contain. The test asserts that:

    * more than one answer is returned,
    * the first answer contains the spyword,
    * the last answer has ``data is None`` (the no_answer entry),
    * the first answer's probability is not lower than the last one's,
    * every answer echoes the question, exposes ``probability``, ``start``,
      ``end`` and ``document``, and, when it has a document, that document is
      the one paired with the question.
    """
    # (document text, question answered by that document, word the best answer must contain)
    cases = [
        ("My name is Jean and I live in Paris.", "Who lives in Paris?", "Jean"),
        ("My name is Mark and I live in Berlin.", "Who lives in Berlin?", "Mark"),
        ("My name is Giorgio and I live in Rome.", "Who lives in Rome?", "Giorgio"),
    ]

    store = MemoryDocumentStore()
    docs = [Document(text=text) for text, _, _ in cases]
    store.write_documents(docs)

    # Wire the retriever into the reader.
    pipeline = Pipeline()
    retriever = MemoryBM25Retriever(document_store=store)
    pipeline.add_component(instance=retriever, name="retriever")
    reader = ExtractiveReader(model_name_or_path="deepset/tinyroberta-squad2")
    pipeline.add_component(instance=reader, name="reader")
    pipeline.connect("retriever", "reader")

    for expected_doc, (_, question, spyword) in zip(docs, cases):
        # Both components receive the same query string.
        query_inputs = {"retriever": {"query": question}, "reader": {"query": question}}
        answers = pipeline.run(query_inputs)["reader"]["answers"]

        # At least one real answer plus the no_answer entry.
        assert len(answers) > 1

        top_answer, no_answer = answers[0], answers[-1]

        # The highest-ranked answer must mention the expected name.
        assert spyword in top_answer.data

        # The final entry is the no_answer placeholder.
        assert no_answer.data is None

        # These questions are easy, so the best answer should not score below no_answer.
        assert top_answer.probability >= no_answer.probability

        for candidate in answers:
            assert candidate.query == question

            for field_name in ("probability", "start", "end", "document"):
                assert hasattr(candidate, field_name)

            # Any answer tied to a document must point at the document paired with the question.
            assert candidate.document is None or candidate.document == expected_doc