mirror of
https://github.com/deepset-ai/haystack.git
synced 2026-01-08 21:28:00 +00:00
feat: make AnswerBuilder non batch (#5766)
* make answerbuilder non batch * fix mypy * review feedback * mypy --------- Co-authored-by: bogdankostic <bogdankostic@web.de>
This commit is contained in:
parent
784034ffc3
commit
335a09bc1d
@ -37,23 +37,23 @@ class AnswerBuilder:
|
||||
self.pattern = pattern
|
||||
self.reference_pattern = reference_pattern
|
||||
|
||||
@component.output_types(answers=List[List[GeneratedAnswer]])
|
||||
@component.output_types(answers=List[GeneratedAnswer])
|
||||
def run(
|
||||
self,
|
||||
queries: List[str],
|
||||
replies: List[List[str]],
|
||||
metadata: List[List[Dict[str, Any]]],
|
||||
documents: Optional[List[List[Document]]] = None,
|
||||
query: str,
|
||||
replies: List[str],
|
||||
metadata: List[Dict[str, Any]],
|
||||
documents: Optional[List[Document]] = None,
|
||||
pattern: Optional[str] = None,
|
||||
reference_pattern: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Parse the output of a Generator to `Answer` objects using regular expressions.
|
||||
|
||||
:param queries: The queries used in the prompts for the Generator. A list of strings.
|
||||
:param replies: The output of the Generator. A list of lists of strings.
|
||||
:param metadata: The metadata returned by the Generator. A list of lists of dictionaries.
|
||||
:param documents: The documents used as input to the Generator. A list of lists of `Document` objects. If
|
||||
:param query: The query used in the prompts for the Generator. A string.
|
||||
:param replies: The output of the Generator. A list of strings.
|
||||
:param metadata: The metadata returned by the Generator. A list of dictionaries.
|
||||
:param documents: The documents used as input to the Generator. A list of `Document` objects. If
|
||||
`documents` are specified, they are added to the `Answer` objects.
|
||||
If both `documents` and `reference_pattern` are specified, the documents referenced in the
|
||||
Generator output are extracted from the input documents and added to the `Answer` objects.
|
||||
@ -73,43 +73,33 @@ class AnswerBuilder:
|
||||
If not specified, no parsing is done, and all documents are referenced.
|
||||
Default: `None`.
|
||||
"""
|
||||
if len(queries) != len(replies) != len(metadata):
|
||||
raise ValueError(
|
||||
f"Number of queries ({len(queries)}), replies ({len(replies)}), and metadata "
|
||||
f"({len(metadata)}) must match."
|
||||
)
|
||||
|
||||
if len(replies) != len(metadata):
|
||||
raise ValueError(f"Number of replies ({len(replies)}), and metadata ({len(metadata)}) must match.")
|
||||
if pattern:
|
||||
AnswerBuilder._check_num_groups_in_regex(pattern)
|
||||
|
||||
documents = documents or []
|
||||
pattern = pattern or self.pattern
|
||||
reference_pattern = reference_pattern or self.reference_pattern
|
||||
|
||||
all_answers = []
|
||||
for i, (query, reply_list, meta_list) in enumerate(zip(queries, replies, metadata)):
|
||||
doc_list = documents[i] if i < len(documents) else []
|
||||
for reply, meta in zip(replies, metadata):
|
||||
referenced_docs = []
|
||||
if documents:
|
||||
reference_idxs = []
|
||||
if reference_pattern:
|
||||
reference_idxs = AnswerBuilder._extract_reference_idxs(reply, reference_pattern)
|
||||
else:
|
||||
reference_idxs = [doc_idx for doc_idx, _ in enumerate(documents)]
|
||||
|
||||
extracted_answer_strings = AnswerBuilder._extract_answer_strings(reply_list, pattern)
|
||||
|
||||
if doc_list and reference_pattern:
|
||||
reference_idxs = AnswerBuilder._extract_reference_idxs(reply_list, reference_pattern)
|
||||
else:
|
||||
reference_idxs = [[doc_idx for doc_idx, _ in enumerate(doc_list)] for _ in reply_list]
|
||||
|
||||
answers_for_cur_query = []
|
||||
for answer_string, doc_idxs, meta in zip(extracted_answer_strings, reference_idxs, meta_list):
|
||||
referenced_docs = []
|
||||
for idx in doc_idxs:
|
||||
if idx < len(doc_list):
|
||||
referenced_docs.append(doc_list[idx])
|
||||
else:
|
||||
for idx in reference_idxs:
|
||||
try:
|
||||
referenced_docs.append(documents[idx])
|
||||
except IndexError:
|
||||
logger.warning("Document index '%s' referenced in Generator output is out of range. ", idx + 1)
|
||||
|
||||
answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, metadata=meta)
|
||||
answers_for_cur_query.append(answer)
|
||||
|
||||
all_answers.append(answers_for_cur_query)
|
||||
answer_string = AnswerBuilder._extract_answer_string(reply, pattern)
|
||||
answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, metadata=meta)
|
||||
all_answers.append(answer)
|
||||
|
||||
return {"answers": all_answers}
|
||||
|
||||
@ -127,39 +117,29 @@ class AnswerBuilder:
|
||||
return default_from_dict(cls, data)
|
||||
|
||||
@staticmethod
|
||||
def _extract_answer_strings(replies: List[str], pattern: Optional[str] = None) -> List[str]:
|
||||
def _extract_answer_string(reply: str, pattern: Optional[str] = None) -> str:
|
||||
"""
|
||||
Extract the answer strings from the generator output using the specified pattern.
|
||||
Extract the answer string from the generator output using the specified pattern.
|
||||
If no pattern is specified, the whole string is used as the answer.
|
||||
|
||||
:param replies: The output of the Generator. A list of strings.
|
||||
:param reply: The output of the Generator. A string.
|
||||
:param pattern: The regular expression pattern to use to extract the answer text from the generator output.
|
||||
"""
|
||||
if pattern is None:
|
||||
return replies
|
||||
return reply
|
||||
|
||||
extracted_answers = []
|
||||
for reply in replies:
|
||||
if match := re.search(pattern, reply):
|
||||
# No capture group in pattern -> use the whole match as answer
|
||||
if not match.lastindex:
|
||||
extracted_answers.append(match.group(0))
|
||||
# One capture group in pattern -> use the capture group as answer
|
||||
else:
|
||||
extracted_answers.append(match.group(1))
|
||||
else:
|
||||
extracted_answers.append("")
|
||||
|
||||
return extracted_answers
|
||||
if match := re.search(pattern, reply):
|
||||
# No capture group in pattern -> use the whole match as answer
|
||||
if not match.lastindex:
|
||||
return match.group(0)
|
||||
# One capture group in pattern -> use the capture group as answer
|
||||
return match.group(1)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _extract_reference_idxs(replies: List[str], reference_pattern: str) -> List[List[int]]:
|
||||
reference_idxs = []
|
||||
for reply in replies:
|
||||
document_idxs = re.findall(reference_pattern, reply)
|
||||
reference_idxs.append([int(idx) - 1 for idx in document_idxs])
|
||||
|
||||
return reference_idxs
|
||||
def _extract_reference_idxs(reply: str, reference_pattern: str) -> List[int]:
|
||||
document_idxs = re.findall(reference_pattern, reply)
|
||||
return [int(idx) - 1 for idx in document_idxs]
|
||||
|
||||
@staticmethod
|
||||
def _check_num_groups_in_regex(pattern: str):
|
||||
|
||||
@ -36,130 +36,122 @@ class TestAnswerBuilder:
|
||||
def test_run_unmatching_input_len(self):
|
||||
component = AnswerBuilder()
|
||||
with pytest.raises(ValueError):
|
||||
component.run(queries=["query"], replies=[["reply1"], ["reply2"]], metadata=[[]])
|
||||
component.run(query="query", replies=["reply1", "reply2"], metadata=[])
|
||||
|
||||
def test_run_without_pattern(self):
|
||||
component = AnswerBuilder()
|
||||
output = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]])
|
||||
output = component.run(query="test query", replies=["Answer: AnswerString"], metadata=[{}])
|
||||
answers = output["answers"]
|
||||
assert len(answers) == 1
|
||||
assert len(answers[0]) == 1
|
||||
assert answers[0][0].data == "Answer: AnswerString"
|
||||
assert answers[0][0].metadata == {}
|
||||
assert answers[0][0].query == "test query"
|
||||
assert answers[0][0].documents == []
|
||||
assert isinstance(answers[0][0], GeneratedAnswer)
|
||||
assert answers[0].data == "Answer: AnswerString"
|
||||
assert answers[0].metadata == {}
|
||||
assert answers[0].query == "test query"
|
||||
assert answers[0].documents == []
|
||||
assert isinstance(answers[0], GeneratedAnswer)
|
||||
|
||||
def test_run_with_pattern_with_capturing_group(self):
|
||||
component = AnswerBuilder(pattern=r"Answer: (.*)")
|
||||
output = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]])
|
||||
output = component.run(query="test query", replies=["Answer: AnswerString"], metadata=[{}])
|
||||
answers = output["answers"]
|
||||
assert len(answers) == 1
|
||||
assert len(answers[0]) == 1
|
||||
assert answers[0][0].data == "AnswerString"
|
||||
assert answers[0][0].metadata == {}
|
||||
assert answers[0][0].query == "test query"
|
||||
assert answers[0][0].documents == []
|
||||
assert isinstance(answers[0][0], GeneratedAnswer)
|
||||
assert answers[0].data == "AnswerString"
|
||||
assert answers[0].metadata == {}
|
||||
assert answers[0].query == "test query"
|
||||
assert answers[0].documents == []
|
||||
assert isinstance(answers[0], GeneratedAnswer)
|
||||
|
||||
def test_run_with_pattern_without_capturing_group(self):
|
||||
component = AnswerBuilder(pattern=r"'.*'")
|
||||
output = component.run(queries=["test query"], replies=[["Answer: 'AnswerString'"]], metadata=[[{}]])
|
||||
output = component.run(query="test query", replies=["Answer: 'AnswerString'"], metadata=[{}])
|
||||
answers = output["answers"]
|
||||
assert len(answers) == 1
|
||||
assert len(answers[0]) == 1
|
||||
assert answers[0][0].data == "'AnswerString'"
|
||||
assert answers[0][0].metadata == {}
|
||||
assert answers[0][0].query == "test query"
|
||||
assert answers[0][0].documents == []
|
||||
assert isinstance(answers[0][0], GeneratedAnswer)
|
||||
assert answers[0].data == "'AnswerString'"
|
||||
assert answers[0].metadata == {}
|
||||
assert answers[0].query == "test query"
|
||||
assert answers[0].documents == []
|
||||
assert isinstance(answers[0], GeneratedAnswer)
|
||||
|
||||
def test_run_with_pattern_with_more_than_one_capturing_group(self):
|
||||
with pytest.raises(ValueError, match="contains multiple capture groups"):
|
||||
component = AnswerBuilder(pattern=r"Answer: (.*), (.*)")
|
||||
AnswerBuilder(pattern=r"Answer: (.*), (.*)")
|
||||
|
||||
def test_run_with_pattern_set_at_runtime(self):
|
||||
component = AnswerBuilder(pattern="unused pattern")
|
||||
output = component.run(
|
||||
queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]], pattern=r"Answer: (.*)"
|
||||
query="test query", replies=["Answer: AnswerString"], metadata=[{}], pattern=r"Answer: (.*)"
|
||||
)
|
||||
answers = output["answers"]
|
||||
assert len(answers) == 1
|
||||
assert len(answers[0]) == 1
|
||||
assert answers[0][0].data == "AnswerString"
|
||||
assert answers[0][0].metadata == {}
|
||||
assert answers[0][0].query == "test query"
|
||||
assert answers[0][0].documents == []
|
||||
assert isinstance(answers[0][0], GeneratedAnswer)
|
||||
assert answers[0].data == "AnswerString"
|
||||
assert answers[0].metadata == {}
|
||||
assert answers[0].query == "test query"
|
||||
assert answers[0].documents == []
|
||||
assert isinstance(answers[0], GeneratedAnswer)
|
||||
|
||||
def test_run_with_documents_without_reference_pattern(self):
|
||||
component = AnswerBuilder()
|
||||
output = component.run(
|
||||
queries=["test query"],
|
||||
replies=[["Answer: AnswerString"]],
|
||||
metadata=[[{}]],
|
||||
documents=[[Document(text="test doc 1"), Document(text="test doc 2")]],
|
||||
query="test query",
|
||||
replies=["Answer: AnswerString"],
|
||||
metadata=[{}],
|
||||
documents=[Document(text="test doc 1"), Document(text="test doc 2")],
|
||||
)
|
||||
answers = output["answers"]
|
||||
assert len(answers) == 1
|
||||
assert len(answers[0]) == 1
|
||||
assert answers[0][0].data == "Answer: AnswerString"
|
||||
assert answers[0][0].metadata == {}
|
||||
assert answers[0][0].query == "test query"
|
||||
assert len(answers[0][0].documents) == 2
|
||||
assert answers[0][0].documents[0].text == "test doc 1"
|
||||
assert answers[0][0].documents[1].text == "test doc 2"
|
||||
assert answers[0].data == "Answer: AnswerString"
|
||||
assert answers[0].metadata == {}
|
||||
assert answers[0].query == "test query"
|
||||
assert len(answers[0].documents) == 2
|
||||
assert answers[0].documents[0].text == "test doc 1"
|
||||
assert answers[0].documents[1].text == "test doc 2"
|
||||
|
||||
def test_run_with_documents_with_reference_pattern(self):
|
||||
component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]")
|
||||
output = component.run(
|
||||
queries=["test query"],
|
||||
replies=[["Answer: AnswerString[2]"]],
|
||||
metadata=[[{}]],
|
||||
documents=[[Document(text="test doc 1"), Document(text="test doc 2")]],
|
||||
query="test query",
|
||||
replies=["Answer: AnswerString[2]"],
|
||||
metadata=[{}],
|
||||
documents=[Document(text="test doc 1"), Document(text="test doc 2")],
|
||||
)
|
||||
answers = output["answers"]
|
||||
assert len(answers) == 1
|
||||
assert len(answers[0]) == 1
|
||||
assert answers[0][0].data == "Answer: AnswerString[2]"
|
||||
assert answers[0][0].metadata == {}
|
||||
assert answers[0][0].query == "test query"
|
||||
assert len(answers[0][0].documents) == 1
|
||||
assert answers[0][0].documents[0].text == "test doc 2"
|
||||
assert answers[0].data == "Answer: AnswerString[2]"
|
||||
assert answers[0].metadata == {}
|
||||
assert answers[0].query == "test query"
|
||||
assert len(answers[0].documents) == 1
|
||||
assert answers[0].documents[0].text == "test doc 2"
|
||||
|
||||
def test_run_with_documents_with_reference_pattern_and_no_match(self, caplog):
|
||||
component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]")
|
||||
with caplog.at_level(logging.WARNING):
|
||||
output = component.run(
|
||||
queries=["test query"],
|
||||
replies=[["Answer: AnswerString[3]"]],
|
||||
metadata=[[{}]],
|
||||
documents=[[Document(text="test doc 1"), Document(text="test doc 2")]],
|
||||
query="test query",
|
||||
replies=["Answer: AnswerString[3]"],
|
||||
metadata=[{}],
|
||||
documents=[Document(text="test doc 1"), Document(text="test doc 2")],
|
||||
)
|
||||
answers = output["answers"]
|
||||
assert len(answers) == 1
|
||||
assert len(answers[0]) == 1
|
||||
assert answers[0][0].data == "Answer: AnswerString[3]"
|
||||
assert answers[0][0].metadata == {}
|
||||
assert answers[0][0].query == "test query"
|
||||
assert len(answers[0][0].documents) == 0
|
||||
assert answers[0].data == "Answer: AnswerString[3]"
|
||||
assert answers[0].metadata == {}
|
||||
assert answers[0].query == "test query"
|
||||
assert len(answers[0].documents) == 0
|
||||
assert "Document index '3' referenced in Generator output is out of range." in caplog.text
|
||||
|
||||
def test_run_with_reference_pattern_set_at_runtime(self):
|
||||
component = AnswerBuilder(reference_pattern="unused pattern")
|
||||
output = component.run(
|
||||
queries=["test query"],
|
||||
replies=[["Answer: AnswerString[2][3]"]],
|
||||
metadata=[[{}]],
|
||||
documents=[[Document(text="test doc 1"), Document(text="test doc 2"), Document(text="test doc 3")]],
|
||||
query="test query",
|
||||
replies=["Answer: AnswerString[2][3]"],
|
||||
metadata=[{}],
|
||||
documents=[Document(text="test doc 1"), Document(text="test doc 2"), Document(text="test doc 3")],
|
||||
reference_pattern="\\[(\\d+)\\]",
|
||||
)
|
||||
answers = output["answers"]
|
||||
assert len(answers) == 1
|
||||
assert len(answers[0]) == 1
|
||||
assert answers[0][0].data == "Answer: AnswerString[2][3]"
|
||||
assert answers[0][0].metadata == {}
|
||||
assert answers[0][0].query == "test query"
|
||||
assert len(answers[0][0].documents) == 2
|
||||
assert answers[0][0].documents[0].text == "test doc 2"
|
||||
assert answers[0][0].documents[1].text == "test doc 3"
|
||||
assert answers[0].data == "Answer: AnswerString[2][3]"
|
||||
assert answers[0].metadata == {}
|
||||
assert answers[0].query == "test query"
|
||||
assert len(answers[0].documents) == 2
|
||||
assert answers[0].documents[0].text == "test doc 2"
|
||||
assert answers[0].documents[1].text == "test doc 3"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user