feat: make AnswerBuilder non batch (#5766)

* make answerbuilder non batch

* fix mypy

* review feedback

* mypy

---------

Co-authored-by: bogdankostic <bogdankostic@web.de>
This commit is contained in:
ZanSara 2023-09-13 11:01:16 +01:00 committed by GitHub
parent 784034ffc3
commit 335a09bc1d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 103 additions and 131 deletions

View File

@ -37,23 +37,23 @@ class AnswerBuilder:
self.pattern = pattern
self.reference_pattern = reference_pattern
@component.output_types(answers=List[List[GeneratedAnswer]])
@component.output_types(answers=List[GeneratedAnswer])
def run(
self,
queries: List[str],
replies: List[List[str]],
metadata: List[List[Dict[str, Any]]],
documents: Optional[List[List[Document]]] = None,
query: str,
replies: List[str],
metadata: List[Dict[str, Any]],
documents: Optional[List[Document]] = None,
pattern: Optional[str] = None,
reference_pattern: Optional[str] = None,
):
"""
Parse the output of a Generator to `Answer` objects using regular expressions.
:param queries: The queries used in the prompts for the Generator. A list of strings.
:param replies: The output of the Generator. A list of lists of strings.
:param metadata: The metadata returned by the Generator. A list of lists of dictionaries.
:param documents: The documents used as input to the Generator. A list of lists of `Document` objects. If
:param query: The query used in the prompts for the Generator. A string.
:param replies: The output of the Generator. A list of strings.
:param metadata: The metadata returned by the Generator. A list of dictionaries.
:param documents: The documents used as input to the Generator. A list of `Document` objects. If
`documents` are specified, they are added to the `Answer` objects.
If both `documents` and `reference_pattern` are specified, the documents referenced in the
Generator output are extracted from the input documents and added to the `Answer` objects.
@ -73,43 +73,33 @@ class AnswerBuilder:
If not specified, no parsing is done, and all documents are referenced.
Default: `None`.
"""
if len(queries) != len(replies) != len(metadata):
raise ValueError(
f"Number of queries ({len(queries)}), replies ({len(replies)}), and metadata "
f"({len(metadata)}) must match."
)
if len(replies) != len(metadata):
raise ValueError(f"Number of replies ({len(replies)}), and metadata ({len(metadata)}) must match.")
if pattern:
AnswerBuilder._check_num_groups_in_regex(pattern)
documents = documents or []
pattern = pattern or self.pattern
reference_pattern = reference_pattern or self.reference_pattern
all_answers = []
for i, (query, reply_list, meta_list) in enumerate(zip(queries, replies, metadata)):
doc_list = documents[i] if i < len(documents) else []
for reply, meta in zip(replies, metadata):
referenced_docs = []
if documents:
reference_idxs = []
if reference_pattern:
reference_idxs = AnswerBuilder._extract_reference_idxs(reply, reference_pattern)
else:
reference_idxs = [doc_idx for doc_idx, _ in enumerate(documents)]
extracted_answer_strings = AnswerBuilder._extract_answer_strings(reply_list, pattern)
if doc_list and reference_pattern:
reference_idxs = AnswerBuilder._extract_reference_idxs(reply_list, reference_pattern)
else:
reference_idxs = [[doc_idx for doc_idx, _ in enumerate(doc_list)] for _ in reply_list]
answers_for_cur_query = []
for answer_string, doc_idxs, meta in zip(extracted_answer_strings, reference_idxs, meta_list):
referenced_docs = []
for idx in doc_idxs:
if idx < len(doc_list):
referenced_docs.append(doc_list[idx])
else:
for idx in reference_idxs:
try:
referenced_docs.append(documents[idx])
except IndexError:
logger.warning("Document index '%s' referenced in Generator output is out of range. ", idx + 1)
answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, metadata=meta)
answers_for_cur_query.append(answer)
all_answers.append(answers_for_cur_query)
answer_string = AnswerBuilder._extract_answer_string(reply, pattern)
answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, metadata=meta)
all_answers.append(answer)
return {"answers": all_answers}
@ -127,39 +117,29 @@ class AnswerBuilder:
return default_from_dict(cls, data)
@staticmethod
def _extract_answer_strings(replies: List[str], pattern: Optional[str] = None) -> List[str]:
def _extract_answer_string(reply: str, pattern: Optional[str] = None) -> str:
"""
Extract the answer strings from the generator output using the specified pattern.
Extract the answer string from the generator output using the specified pattern.
If no pattern is specified, the whole string is used as the answer.
:param replies: The output of the Generator. A list of strings.
:param reply: The output of the Generator. A string.
:param pattern: The regular expression pattern to use to extract the answer text from the generator output.
"""
if pattern is None:
return replies
return reply
extracted_answers = []
for reply in replies:
if match := re.search(pattern, reply):
# No capture group in pattern -> use the whole match as answer
if not match.lastindex:
extracted_answers.append(match.group(0))
# One capture group in pattern -> use the capture group as answer
else:
extracted_answers.append(match.group(1))
else:
extracted_answers.append("")
return extracted_answers
if match := re.search(pattern, reply):
# No capture group in pattern -> use the whole match as answer
if not match.lastindex:
return match.group(0)
# One capture group in pattern -> use the capture group as answer
return match.group(1)
return ""
@staticmethod
def _extract_reference_idxs(replies: List[str], reference_pattern: str) -> List[List[int]]:
reference_idxs = []
for reply in replies:
document_idxs = re.findall(reference_pattern, reply)
reference_idxs.append([int(idx) - 1 for idx in document_idxs])
return reference_idxs
def _extract_reference_idxs(reply: str, reference_pattern: str) -> List[int]:
document_idxs = re.findall(reference_pattern, reply)
return [int(idx) - 1 for idx in document_idxs]
@staticmethod
def _check_num_groups_in_regex(pattern: str):

View File

@ -36,130 +36,122 @@ class TestAnswerBuilder:
def test_run_unmatching_input_len(self):
component = AnswerBuilder()
with pytest.raises(ValueError):
component.run(queries=["query"], replies=[["reply1"], ["reply2"]], metadata=[[]])
component.run(query="query", replies=["reply1", "reply2"], metadata=[])
def test_run_without_pattern(self):
component = AnswerBuilder()
output = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]])
output = component.run(query="test query", replies=["Answer: AnswerString"], metadata=[{}])
answers = output["answers"]
assert len(answers) == 1
assert len(answers[0]) == 1
assert answers[0][0].data == "Answer: AnswerString"
assert answers[0][0].metadata == {}
assert answers[0][0].query == "test query"
assert answers[0][0].documents == []
assert isinstance(answers[0][0], GeneratedAnswer)
assert answers[0].data == "Answer: AnswerString"
assert answers[0].metadata == {}
assert answers[0].query == "test query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)
def test_run_with_pattern_with_capturing_group(self):
component = AnswerBuilder(pattern=r"Answer: (.*)")
output = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]])
output = component.run(query="test query", replies=["Answer: AnswerString"], metadata=[{}])
answers = output["answers"]
assert len(answers) == 1
assert len(answers[0]) == 1
assert answers[0][0].data == "AnswerString"
assert answers[0][0].metadata == {}
assert answers[0][0].query == "test query"
assert answers[0][0].documents == []
assert isinstance(answers[0][0], GeneratedAnswer)
assert answers[0].data == "AnswerString"
assert answers[0].metadata == {}
assert answers[0].query == "test query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)
def test_run_with_pattern_without_capturing_group(self):
component = AnswerBuilder(pattern=r"'.*'")
output = component.run(queries=["test query"], replies=[["Answer: 'AnswerString'"]], metadata=[[{}]])
output = component.run(query="test query", replies=["Answer: 'AnswerString'"], metadata=[{}])
answers = output["answers"]
assert len(answers) == 1
assert len(answers[0]) == 1
assert answers[0][0].data == "'AnswerString'"
assert answers[0][0].metadata == {}
assert answers[0][0].query == "test query"
assert answers[0][0].documents == []
assert isinstance(answers[0][0], GeneratedAnswer)
assert answers[0].data == "'AnswerString'"
assert answers[0].metadata == {}
assert answers[0].query == "test query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)
def test_run_with_pattern_with_more_than_one_capturing_group(self):
with pytest.raises(ValueError, match="contains multiple capture groups"):
component = AnswerBuilder(pattern=r"Answer: (.*), (.*)")
AnswerBuilder(pattern=r"Answer: (.*), (.*)")
def test_run_with_pattern_set_at_runtime(self):
component = AnswerBuilder(pattern="unused pattern")
output = component.run(
queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]], pattern=r"Answer: (.*)"
query="test query", replies=["Answer: AnswerString"], metadata=[{}], pattern=r"Answer: (.*)"
)
answers = output["answers"]
assert len(answers) == 1
assert len(answers[0]) == 1
assert answers[0][0].data == "AnswerString"
assert answers[0][0].metadata == {}
assert answers[0][0].query == "test query"
assert answers[0][0].documents == []
assert isinstance(answers[0][0], GeneratedAnswer)
assert answers[0].data == "AnswerString"
assert answers[0].metadata == {}
assert answers[0].query == "test query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)
def test_run_with_documents_without_reference_pattern(self):
component = AnswerBuilder()
output = component.run(
queries=["test query"],
replies=[["Answer: AnswerString"]],
metadata=[[{}]],
documents=[[Document(text="test doc 1"), Document(text="test doc 2")]],
query="test query",
replies=["Answer: AnswerString"],
metadata=[{}],
documents=[Document(text="test doc 1"), Document(text="test doc 2")],
)
answers = output["answers"]
assert len(answers) == 1
assert len(answers[0]) == 1
assert answers[0][0].data == "Answer: AnswerString"
assert answers[0][0].metadata == {}
assert answers[0][0].query == "test query"
assert len(answers[0][0].documents) == 2
assert answers[0][0].documents[0].text == "test doc 1"
assert answers[0][0].documents[1].text == "test doc 2"
assert answers[0].data == "Answer: AnswerString"
assert answers[0].metadata == {}
assert answers[0].query == "test query"
assert len(answers[0].documents) == 2
assert answers[0].documents[0].text == "test doc 1"
assert answers[0].documents[1].text == "test doc 2"
def test_run_with_documents_with_reference_pattern(self):
component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]")
output = component.run(
queries=["test query"],
replies=[["Answer: AnswerString[2]"]],
metadata=[[{}]],
documents=[[Document(text="test doc 1"), Document(text="test doc 2")]],
query="test query",
replies=["Answer: AnswerString[2]"],
metadata=[{}],
documents=[Document(text="test doc 1"), Document(text="test doc 2")],
)
answers = output["answers"]
assert len(answers) == 1
assert len(answers[0]) == 1
assert answers[0][0].data == "Answer: AnswerString[2]"
assert answers[0][0].metadata == {}
assert answers[0][0].query == "test query"
assert len(answers[0][0].documents) == 1
assert answers[0][0].documents[0].text == "test doc 2"
assert answers[0].data == "Answer: AnswerString[2]"
assert answers[0].metadata == {}
assert answers[0].query == "test query"
assert len(answers[0].documents) == 1
assert answers[0].documents[0].text == "test doc 2"
def test_run_with_documents_with_reference_pattern_and_no_match(self, caplog):
component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]")
with caplog.at_level(logging.WARNING):
output = component.run(
queries=["test query"],
replies=[["Answer: AnswerString[3]"]],
metadata=[[{}]],
documents=[[Document(text="test doc 1"), Document(text="test doc 2")]],
query="test query",
replies=["Answer: AnswerString[3]"],
metadata=[{}],
documents=[Document(text="test doc 1"), Document(text="test doc 2")],
)
answers = output["answers"]
assert len(answers) == 1
assert len(answers[0]) == 1
assert answers[0][0].data == "Answer: AnswerString[3]"
assert answers[0][0].metadata == {}
assert answers[0][0].query == "test query"
assert len(answers[0][0].documents) == 0
assert answers[0].data == "Answer: AnswerString[3]"
assert answers[0].metadata == {}
assert answers[0].query == "test query"
assert len(answers[0].documents) == 0
assert "Document index '3' referenced in Generator output is out of range." in caplog.text
def test_run_with_reference_pattern_set_at_runtime(self):
component = AnswerBuilder(reference_pattern="unused pattern")
output = component.run(
queries=["test query"],
replies=[["Answer: AnswerString[2][3]"]],
metadata=[[{}]],
documents=[[Document(text="test doc 1"), Document(text="test doc 2"), Document(text="test doc 3")]],
query="test query",
replies=["Answer: AnswerString[2][3]"],
metadata=[{}],
documents=[Document(text="test doc 1"), Document(text="test doc 2"), Document(text="test doc 3")],
reference_pattern="\\[(\\d+)\\]",
)
answers = output["answers"]
assert len(answers) == 1
assert len(answers[0]) == 1
assert answers[0][0].data == "Answer: AnswerString[2][3]"
assert answers[0][0].metadata == {}
assert answers[0][0].query == "test query"
assert len(answers[0][0].documents) == 2
assert answers[0][0].documents[0].text == "test doc 2"
assert answers[0][0].documents[1].text == "test doc 3"
assert answers[0].data == "Answer: AnswerString[2][3]"
assert answers[0].metadata == {}
assert answers[0].query == "test query"
assert len(answers[0].documents) == 2
assert answers[0].documents[0].text == "test doc 2"
assert answers[0].documents[1].text == "test doc 3"