haystack/test/nodes/test_question_generator.py
Sara Zan 05c68b6624
feat: add document_store to all BaseRetriever.retrieve() and BaseRetriever.retrieve_batch() implementations (#3379)
* add document_store to retrieve()

* mypy & pylint

* pass docstore to embedding encoders

* schemas

* mypy and pylint

* fix tfidfretriever

* pylint

* mypy

* pylint

* fix tfidf

* mypy

* pylint

* schemas

* another fix for tfidf

* fix question generation tests

* remove docstore from embedding encoder signature

* pylint

* revert accidental test changes

* Apply suggestions from code review

* check for docstore similarity function only if the docstore is present

* check for docstore similarity function only if the docstore is present
2022-10-26 15:47:06 +02:00

151 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import List
import pytest
from haystack.pipelines import (
QuestionAnswerGenerationPipeline,
QuestionGenerationPipeline,
RetrieverQuestionGenerationPipeline,
)
from haystack.nodes.question_generator import QuestionGenerator
from haystack.schema import Document
# First sample passage (about the band The Living End). It is wrapped in
# `document` below and used as input for question generation in the tests.
text = (
"The Living End are an Australian punk rockabilly band from Melbourne, formed in 1994. Since 2002, "
"the line-up consists of Chris Cheney (vocals, guitar), Scott Owen (double bass, vocals), and Andy "
"Strachan (drums). The band rose to fame in 1997 after the release of their EP Second Solution / Prisoner "
"of Society, which peaked at No. 4 on the Australian ARIA Singles Chart. They have released eight studio "
"albums, two of which reached the No. 1 spot on the ARIA Albums Chart: The Living End (October 1998) and "
"State of Emergency (February 2006). They have also achieved chart success in the U.S. and the United "
"Kingdom. The Band was nominated 27 times and won five awards at the Australian ARIA Music Awards "
'ceremonies: "Highest Selling Single" for Second Solution / Prisoner of Society (1998), "Breakthrough '
'Artist Album" and "Best Group" for The Living End (1999), as well as "Best Rock Album" for White '
"Noise (2008) and The Ending Is Just the Beginning Repeating (2011). In October 2010, their debut album "
'was listed in the book "100 Best Australian Albums". Australian musicologist Ian McFarlane described '
'the group as "one of Australias premier rock acts. By blending a range of styles (punk, rockabilly '
"and flat out rock) with great success, The Living End has managed to produce anthemic choruses and "
'memorable songs in abundance".'
)
# Document built from the first sample passage.
document = Document(content=text)
# Query string passed to the retriever-based pipeline in test_rqg_pipeline.
query = "Living End"
# Keywords for the first passage: verify_questions() requires every question
# generated from `text` to contain at least one of these (case-sensitive substring match).
keywords = [
"Australian",
"punk",
"drummer",
"Living",
"band",
"Band",
"Second",
"album",
"albums",
"dialect",
"music",
"book",
"group",
"produce",
"Music",
"Awards",
"year",
"released",
]
# Second sample passage (about Berlin), wrapped in `document_2` below.
text_2 = (
"Berlin straddles the banks of the Spree, which flows into the Havel (a tributary of the Elbe) in the "
"western borough of Spandau. Among the city's main topographical features are the many lakes in the western "
"and southeastern boroughs formed by the Spree, Havel and Dahme, the largest of which is Lake Müggelsee. "
"Due to its location in the European Plain, Berlin is influenced by a temperate seasonal climate. About "
"one-third of the city's area is composed of forests, parks, gardens, rivers, canals and lakes. The city "
"lies in the Central German dialect area, the Berlin dialect being a variant of the Lusatian-New Marchian "
"dialects."
)
# Document built from the second sample passage.
document_2 = Document(content=text_2)
# Keywords for the second passage, used with verify_questions() in the same
# way as `keywords` is used for the first passage.
keywords_2 = [
"Berlin",
"Elbe",
"Spandau",
"Spree",
"boroughs",
"lakes",
"largest",
"seasonal",
"climate",
"city",
"dialect",
"German",
]
def test_qg_pipeline(question_generator):
    """Run the question generation pipeline on the first sample document.

    The output must expose a "generated_questions" entry whose first element
    holds at least one question.
    """
    pipeline = QuestionGenerationPipeline(question_generator)
    output = pipeline.run(documents=[document])

    assert "generated_questions" in output
    first_entry = output["generated_questions"][0]
    assert len(first_entry["questions"]) > 0
def test_qg_pipeline_non_default_params():
    """Check the pipeline output for a generator built with ``num_queries_per_doc=2``.

    Both sample documents are passed in. The result must be a dict containing
    "generated_questions" and "documents" lists of length two. The expected
    number of questions per document (26 and 12) is hard-coded in this test.
    """
    generator = QuestionGenerator(model_name_or_path="valhalla/t5-small-e2e-qg", num_queries_per_doc=2)
    pipeline = QuestionGenerationPipeline(generator)
    output = pipeline.run(documents=[document, document_2])

    # Overall structure of the pipeline output.
    assert isinstance(output, dict)
    assert "generated_questions" in output
    assert "documents" in output
    assert isinstance(output["generated_questions"], list)
    assert isinstance(output["documents"], list)
    assert len(output["generated_questions"]) == 2
    assert len(output["documents"]) == 2

    # Number of questions generated for each input document.
    generated = output["generated_questions"]
    assert len(generated[0]["questions"]) == 26
    assert len(generated[1]["questions"]) == 12

    # first list of questions should be about Australian punk band
    verify_questions(generated[0]["questions"], keywords)
    # second list of questions should be about Berlin
    verify_questions(generated[1]["questions"], keywords_2)
@pytest.mark.parametrize(
    "split_length, num_queries_per_doc",
    [
        (50, 1),
        (50, 2),
        (50, 3),
        (100, 1),
        (100, 2),
        (100, 3),
    ],
)
def test_qa_generator_non_default_params(split_length, num_queries_per_doc):
    """Call ``generate_batch`` directly with several split/query settings.

    For every parameter combination the generator must return one non-empty
    list of questions per input text, and each question must mention a keyword
    of the text it was generated from.
    """
    generator = QuestionGenerator(
        model_name_or_path="valhalla/t5-small-e2e-qg",
        split_length=split_length,
        num_queries_per_doc=num_queries_per_doc,
    )
    input_texts = [document.content, document_2.content]
    questions = generator.generate_batch(texts=input_texts)

    assert isinstance(questions, list)
    assert len(questions) == 2

    # Exactly two entries are guaranteed by the assertion above.
    band_questions, berlin_questions = questions
    assert isinstance(band_questions, list)
    assert isinstance(berlin_questions, list)
    assert len(band_questions) > 0
    assert len(berlin_questions) > 0

    # first list of questions should be about Australian punk band
    verify_questions(band_questions, keywords)
    # second list of questions should be about Berlin
    verify_questions(berlin_questions, keywords_2)
@pytest.mark.parametrize(
    "retriever,document_store",
    [("tfidf", "memory")],
    indirect=True,
)
def test_rqg_pipeline(question_generator, retriever):
    """Run the retriever + question generator pipeline for the sample query.

    The first sample document is written to the retriever's document store
    before the pipeline runs; the output must then contain at least one
    generated question in its first "generated_questions" entry.
    """
    retriever.document_store.write_documents([document])
    pipeline = RetrieverQuestionGenerationPipeline(retriever, question_generator)
    output = pipeline.run(query)

    assert "generated_questions" in output
    first_entry = output["generated_questions"][0]
    assert len(first_entry["questions"]) > 0
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
def test_qag_pipeline(question_generator, reader):
    """Run the question + answer generation pipeline on the first sample document.

    The output must contain "queries" and "answers" of equal, non-zero length,
    and the first answer of the first entry must have a non-None ``answer``.
    """
    pipeline = QuestionAnswerGenerationPipeline(question_generator, reader)
    output = pipeline.run(documents=[document])

    assert "queries" in output
    assert "answers" in output

    answers = output["answers"]
    assert len(output["queries"]) == len(answers)
    assert len(answers) > 0
    assert answers[0][0].answer is not None
def verify_questions(questions: List[str], question_keywords: List[str]):
    """Assert that every question contains at least one of the given keywords.

    Matching is a case-sensitive substring check (``keyword in question``).
    An empty ``questions`` list passes trivially.

    :param questions: Generated question strings to check.
    :param question_keywords: Keywords of which at least one must occur in each question.
    :raises AssertionError: If a question contains none of the keywords. The message
        names the offending question and the expected keywords.
    """
    for question in questions:
        # The explicit message matters: a bare ``assert any(<generator>)`` only reports
        # "False" and hides which generated question failed the check.
        assert any(keyword in question for keyword in question_keywords), (
            f"Question {question!r} contains none of the expected keywords: {question_keywords}"
        )