haystack/test/test_question_generator.py
bogdankostic 738e008020
Add run_batch method to all nodes and Pipeline to allow batch querying (#2481)
* Add run_batch methods for batch querying

* Update Documentation & Code Style

* Fix mypy

* Update Documentation & Code Style

* Fix mypy

* Fix linter

* Fix tests

* Update Documentation & Code Style

* Fix tests

* Update Documentation & Code Style

* Fix mypy

* Fix rest api test

* Update Documentation & Code Style

* Add Doc strings

* Update Documentation & Code Style

* Add batch_size as attribute to nodes supporting batching

* Adapt error messages

* Adapt type of filters in retrievers

* Revert change about truncation_warning in summarizer

* Unify multiple_doc_lists tests

* Use smaller models in extractor tests

* Add return types to JoinAnswers and RouteDocuments

* Adapt return statements in reader's run_batch method

* Allow list of filters

* Adapt error messages

* Update Documentation & Code Style

* Fix tests

* Fix mypy

* Adapt print_questions

* Remove disabling warning about too many public methods

* Add flag for pylint to disable warning about too many public methods in pipelines/base.py and document_stores/base.py

* Add type check

* Update Documentation & Code Style

* Adapt tutorial 11

* Update Documentation & Code Style

* Add query_batch method for DCDocStore

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
2022-05-11 11:11:00 +02:00
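The commit above adds `run_batch` methods to nodes and pipelines and stores `batch_size` as an attribute on nodes that support batching. Below is a minimal, framework-free sketch of that idea. It assumes a node that processes queries in fixed-size chunks. `ToyBatchNode`, `chunked`, and `process_batch` are hypothetical names for illustration only. They are not Haystack's implementation or API.

```python
from typing import Callable, Iterator, List, Sequence, TypeVar, Union

T = TypeVar("T")


def chunked(items: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive slices of at most `batch_size` items (hypothetical helper)."""
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


class ToyBatchNode:
    """Hypothetical node: batch_size is kept as an attribute, as the commit describes."""

    def __init__(self, process_batch: Callable[[List[str]], List[str]], batch_size: int = 2):
        self.process_batch = process_batch
        self.batch_size = batch_size

    def run(self, query: str) -> str:
        # Single-query path: treat the query as a batch of one.
        return self.process_batch([query])[0]

    def run_batch(self, queries: Union[str, List[str]]) -> List[str]:
        # Accept a single query for convenience and normalise it to a list.
        query_list = [queries] if isinstance(queries, str) else list(queries)
        results: List[str] = []
        # Process the queries in chunks of at most self.batch_size.
        for batch in chunked(query_list, self.batch_size):
            results.extend(self.process_batch(batch))
        return results


if __name__ == "__main__":
    seen_batch_sizes: List[int] = []

    def upper_batch(batch: List[str]) -> List[str]:
        # Record how many queries arrive per call, then "process" them.
        seen_batch_sizes.append(len(batch))
        return [q.upper() for q in batch]

    node = ToyBatchNode(upper_batch, batch_size=2)
    print(node.run_batch(["a", "b", "c"]))  # ['A', 'B', 'C']
    print(seen_batch_sizes)  # [2, 1]
```

Keeping `batch_size` on the node lets one pipeline call decide how many inputs each model invocation sees. Callers only hand over the full list.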

import pytest

from haystack.pipelines import (
    QuestionAnswerGenerationPipeline,
    QuestionGenerationPipeline,
    RetrieverQuestionGenerationPipeline,
)
from haystack.schema import Document

text = 'The Living End are an Australian punk rockabilly band from Melbourne, formed in 1994. Since 2002, the line-up consists of Chris Cheney (vocals, guitar), Scott Owen (double bass, vocals), and Andy Strachan (drums). The band rose to fame in 1997 after the release of their EP Second Solution / Prisoner of Society, which peaked at No. 4 on the Australian ARIA Singles Chart. They have released eight studio albums, two of which reached the No. 1 spot on the ARIA Albums Chart: The Living End (October 1998) and State of Emergency (February 2006). They have also achieved chart success in the U.S. and the United Kingdom. The Band was nominated 27 times and won five awards at the Australian ARIA Music Awards ceremonies: "Highest Selling Single" for Second Solution / Prisoner of Society (1998), "Breakthrough Artist Album" and "Best Group" for The Living End (1999), as well as "Best Rock Album" for White Noise (2008) and The Ending Is Just the Beginning Repeating (2011). In October 2010, their debut album was listed in the book "100 Best Australian Albums". Australian musicologist Ian McFarlane described the group as "one of Australia\'s premier rock acts. By blending a range of styles (punk, rockabilly and flat out rock) with great success, The Living End has managed to produce anthemic choruses and memorable songs in abundance".'
document = Document(content=text)
query = "Living End"


def test_qg_pipeline(question_generator):
    p = QuestionGenerationPipeline(question_generator)
    result = p.run(documents=[document])
    keys = list(result)
    assert "generated_questions" in keys
    assert len(result["generated_questions"][0]["questions"]) > 0


@pytest.mark.parametrize("retriever,document_store", [("tfidf", "memory")], indirect=True)
def test_rqg_pipeline(question_generator, retriever):
    retriever.document_store.write_documents([document])
    retriever.fit()
    p = RetrieverQuestionGenerationPipeline(retriever, question_generator)
    result = p.run(query)
    keys = list(result)
    assert "generated_questions" in keys
    assert len(result["generated_questions"][0]["questions"]) > 0


@pytest.mark.parametrize("reader", ["farm"], indirect=True)
def test_qag_pipeline(question_generator, reader):
    p = QuestionAnswerGenerationPipeline(question_generator, reader)
    results = p.run(documents=[document])
    assert "queries" in results
    assert "answers" in results
    assert len(results["queries"]) == len(results["answers"])
    assert len(results["answers"]) > 0
    assert results["answers"][0][0].answer is not None