mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-07-25 18:00:28 +00:00

* add document_store to retrieve() * mypy & pylint * pass docstore to embedding encoders * schemas * mypy and pylint * fix tfidfretriever * pylint * mypy * pylint * fix tfidf * mypy * pylint * schemas * another fix for tfidf * fix question generation tests * remove docstore from embedding encoder signature * pylint * revert accidental test changes * Apply suggestions from code review * check for docstore similarity function only if the docstore is present * check for docstore similarity function only if the docstore is present
151 lines
6.1 KiB
Python
151 lines
6.1 KiB
Python
from typing import List
|
||
|
||
import pytest
|
||
|
||
from haystack.pipelines import (
|
||
QuestionAnswerGenerationPipeline,
|
||
QuestionGenerationPipeline,
|
||
RetrieverQuestionGenerationPipeline,
|
||
)
|
||
from haystack.nodes.question_generator import QuestionGenerator
|
||
from haystack.schema import Document
|
||
|
||
|
||
# --- Fixture 1: passage about the band "The Living End" ----------------------
# The passage is assembled from adjacent string literals; the pieces are
# concatenated by the parser into one single string.
text = (
    "The Living End are an Australian punk rockabilly band from Melbourne, formed in 1994. Since 2002, "
    "the line-up consists of Chris Cheney (vocals, guitar), Scott Owen (double bass, vocals), and Andy "
    "Strachan (drums). The band rose to fame in 1997 after the release of their EP Second Solution / Prisoner "
    "of Society, which peaked at No. 4 on the Australian ARIA Singles Chart. They have released eight studio "
    "albums, two of which reached the No. 1 spot on the ARIA Albums Chart: The Living End (October 1998) and "
    "State of Emergency (February 2006). They have also achieved chart success in the U.S. and the United "
    "Kingdom. The Band was nominated 27 times and won five awards at the Australian ARIA Music Awards "
    'ceremonies: "Highest Selling Single" for Second Solution / Prisoner of Society (1998), "Breakthrough '
    'Artist – Album" and "Best Group" for The Living End (1999), as well as "Best Rock Album" for White '
    "Noise (2008) and The Ending Is Just the Beginning Repeating (2011). In October 2010, their debut album "
    'was listed in the book "100 Best Australian Albums". Australian musicologist Ian McFarlane described '
    'the group as "one of Australia’s premier rock acts. By blending a range of styles (punk, rockabilly '
    "and flat out rock) with great success, The Living End has managed to produce anthemic choruses and "
    'memorable songs in abundance".'
)
document = Document(content=text)

# Query string passed to the retriever-based pipeline test.
query = "Living End"

# Every question generated from `text` must contain at least one of these
# substrings (the check is case-sensitive, hence both "band" and "Band").
keywords = [
    "Australian", "punk", "drummer", "Living", "band", "Band",
    "Second", "album", "albums", "dialect", "music", "book",
    "group", "produce", "Music", "Awards", "year", "released",
]

# --- Fixture 2: passage about Berlin -----------------------------------------
text_2 = (
    "Berlin straddles the banks of the Spree, which flows into the Havel (a tributary of the Elbe) in the "
    "western borough of Spandau. Among the city's main topographical features are the many lakes in the western "
    "and southeastern boroughs formed by the Spree, Havel and Dahme, the largest of which is Lake Müggelsee. "
    "Due to its location in the European Plain, Berlin is influenced by a temperate seasonal climate. About "
    "one-third of the city's area is composed of forests, parks, gardens, rivers, canals and lakes. The city "
    "lies in the Central German dialect area, the Berlin dialect being a variant of the Lusatian-New Marchian "
    "dialects."
)
document_2 = Document(content=text_2)

# Every question generated from `text_2` must contain at least one of these.
keywords_2 = [
    "Berlin", "Elbe", "Spandau", "Spree", "boroughs", "lakes",
    "largest", "seasonal", "climate", "city", "dialect", "German",
]
|
||
|
||
|
||
def test_qg_pipeline(question_generator):
    """Run the question generation pipeline on one document and expect questions back."""
    pipeline = QuestionGenerationPipeline(question_generator)
    output = pipeline.run(documents=[document])

    assert "generated_questions" in output

    # Only one document went in, so inspect the first (index 0) entry.
    questions_for_document = output["generated_questions"][0]["questions"]
    assert len(questions_for_document) > 0
|
||
|
||
|
||
def test_qg_pipeline_non_default_params():
    """Check the pipeline output for two documents with a non-default generator.

    The generator is built with ``num_queries_per_doc=2``; the test pins the
    exact number of questions produced per document (26 and 12) and verifies
    that each question mentions a keyword of the document it came from.
    """
    generator = QuestionGenerator(model_name_or_path="valhalla/t5-small-e2e-qg", num_queries_per_doc=2)
    pipeline = QuestionGenerationPipeline(generator)
    output = pipeline.run(documents=[document, document_2])

    # Overall shape of the pipeline output.
    assert isinstance(output, dict)
    assert "generated_questions" in output
    assert "documents" in output

    generated = output["generated_questions"]
    returned_documents = output["documents"]
    assert isinstance(generated, list)
    assert isinstance(returned_documents, list)

    # One entry per input document.
    assert len(generated) == 2
    assert len(returned_documents) == 2

    band_questions = generated[0]["questions"]
    assert len(band_questions) == 26
    berlin_questions = generated[1]["questions"]
    assert len(berlin_questions) == 12

    # first list of questions should be about Australian punk band
    verify_questions(band_questions, keywords)
    # second list of questions should be about Berlin
    verify_questions(berlin_questions, keywords_2)
|
||
|
||
|
||
@pytest.mark.parametrize(
    "split_length, num_queries_per_doc",
    [
        (50, 1),
        (50, 2),
        (50, 3),
        (100, 1),
        (100, 2),
        (100, 3),
    ],
)
def test_qa_generator_non_default_params(split_length, num_queries_per_doc):
    """Call ``generate_batch`` directly for several split/query-count settings.

    For each parameter pair, both input texts must yield a non-empty list of
    questions, and every question must mention a keyword of its source text.
    """
    generator = QuestionGenerator(
        model_name_or_path="valhalla/t5-small-e2e-qg",
        split_length=split_length,
        num_queries_per_doc=num_queries_per_doc,
    )
    input_texts = [document.content, document_2.content]
    generated = generator.generate_batch(texts=input_texts)

    # One list of questions per input text.
    assert isinstance(generated, list)
    assert len(generated) == 2

    band_questions, berlin_questions = generated
    assert isinstance(band_questions, list)
    assert isinstance(berlin_questions, list)
    assert len(band_questions) > 0
    assert len(berlin_questions) > 0

    # first list of questions should be about Australian punk band
    verify_questions(band_questions, keywords)
    # second list of questions should be about Berlin
    verify_questions(berlin_questions, keywords_2)
|
||
|
||
|
||
@pytest.mark.parametrize("retriever,document_store", [("tfidf", "memory")], indirect=True)
def test_rqg_pipeline(question_generator, retriever):
    """Retrieve by query, then generate questions from what was retrieved."""
    # The document must be in the retriever's store before the pipeline runs.
    retriever.document_store.write_documents([document])

    pipeline = RetrieverQuestionGenerationPipeline(retriever, question_generator)
    output = pipeline.run(query)

    assert "generated_questions" in output

    first_entry_questions = output["generated_questions"][0]["questions"]
    assert len(first_entry_questions) > 0
|
||
|
||
|
||
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
def test_qag_pipeline(question_generator, reader):
    """Generate questions for a document and answer them with the reader."""
    pipeline = QuestionAnswerGenerationPipeline(question_generator, reader)
    output = pipeline.run(documents=[document])

    for expected_key in ("queries", "answers"):
        assert expected_key in output

    queries = output["queries"]
    answers = output["answers"]

    # Queries and answers are parallel: same count, and at least one of each.
    assert len(queries) == len(answers)
    assert len(answers) > 0

    first_answer = answers[0][0]
    assert first_answer.answer is not None
|
||
|
||
|
||
def verify_questions(questions: List[str], question_keywords: List[str]):
    """Assert that every question mentions at least one of the expected keywords.

    Matching is a plain, case-sensitive substring test (``keyword in question``).
    An empty ``questions`` list passes trivially.

    :param questions: Generated question strings to check.
    :param question_keywords: Keywords of which at least one must occur in each question.
    :raises AssertionError: For the first question containing none of the keywords;
        the message names that question and the keywords that were expected.
    """
    for question in questions:
        # A bare `assert any(...)` only reports a generator object on failure,
        # so spell out which question was off-topic.
        assert any(keyword in question for keyword in question_keywords), (
            f"Question {question!r} contains none of the expected keywords: {question_keywords}"
        )
|