feat: Make metadata optional in AnswerBuilder (#5909)

* optional metadata

* improve docstring
This commit is contained in:
ZanSara 2023-09-28 14:42:19 +02:00 committed by GitHub
parent 9340c572f9
commit 83724b74e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 4 deletions

View File

@ -42,7 +42,7 @@ class AnswerBuilder:
self,
query: str,
replies: List[str],
metadata: List[Dict[str, Any]],
metadata: Optional[List[Dict[str, Any]]] = None,
documents: Optional[List[Document]] = None,
pattern: Optional[str] = None,
reference_pattern: Optional[str] = None,
@ -52,7 +52,8 @@ class AnswerBuilder:
:param query: The query used in the prompts for the Generator. A strings.
:param replies: The output of the Generator. A list of strings.
:param metadata: The metadata returned by the Generator. A list of dictionaries.
:param metadata: The metadata returned by the Generator. An optional list of dictionaries. If not specified,
the generated answer will contain no metadata.
:param documents: The documents used as input to the Generator. A list of `Document` objects. If
`documents` are specified, they are added to the `Answer` objects.
If both `documents` and `reference_pattern` are specified, the documents referenced in the
@ -73,8 +74,11 @@ class AnswerBuilder:
If not specified, no parsing is done, and all documents are referenced.
Default: `None`.
"""
if len(replies) != len(metadata):
if not metadata:
metadata = [{}] * len(replies)
elif len(replies) != len(metadata):
raise ValueError(f"Number of replies ({len(replies)}), and metadata ({len(metadata)}) must match.")
if pattern:
AnswerBuilder._check_num_groups_in_regex(pattern)

View File

@ -36,7 +36,29 @@ class TestAnswerBuilder:
def test_run_unmatching_input_len(self):
component = AnswerBuilder()
with pytest.raises(ValueError):
component.run(query="query", replies=["reply1", "reply2"], metadata=[])
component.run(query="query", replies=["reply1"], metadata=[{"test": "meta"}, {"test": "meta2"}])
@pytest.mark.unit
def test_run_without_meta(self):
component = AnswerBuilder()
output = component.run(query="query", replies=["reply1"])
answers = output["answers"]
assert answers[0].data == "reply1"
assert answers[0].metadata == {}
assert answers[0].query == "query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)
@pytest.mark.unit
def test_run_meta_is_an_empty_list(self):
component = AnswerBuilder()
output = component.run(query="query", replies=["reply1"], metadata=[])
answers = output["answers"]
assert answers[0].data == "reply1"
assert answers[0].metadata == {}
assert answers[0].query == "query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)
def test_run_without_pattern(self):
component = AnswerBuilder()