From 83724b74e30707c1bf8ab8c9758e8e52a1c59eed Mon Sep 17 00:00:00 2001 From: ZanSara Date: Thu, 28 Sep 2023 14:42:19 +0200 Subject: [PATCH] feat: Make `metadata` optional in AnswerBuilder (#5909) * optional metadata * improve docstring --- .../components/builders/answer_builder.py | 10 +++++--- .../builders/test_answer_builder.py | 24 ++++++++++++++++++- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/haystack/preview/components/builders/answer_builder.py b/haystack/preview/components/builders/answer_builder.py index af644d22a..201d43b9c 100644 --- a/haystack/preview/components/builders/answer_builder.py +++ b/haystack/preview/components/builders/answer_builder.py @@ -42,7 +42,7 @@ class AnswerBuilder: self, query: str, replies: List[str], - metadata: List[Dict[str, Any]], + metadata: Optional[List[Dict[str, Any]]] = None, documents: Optional[List[Document]] = None, pattern: Optional[str] = None, reference_pattern: Optional[str] = None, @@ -52,7 +52,8 @@ class AnswerBuilder: :param query: The query used in the prompts for the Generator. A strings. :param replies: The output of the Generator. A list of strings. - :param metadata: The metadata returned by the Generator. A list of dictionaries. + :param metadata: The metadata returned by the Generator. An optional list of dictionaries. If not specified, + the generated answer will contain no metadata. :param documents: The documents used as input to the Generator. A list of `Document` objects. If `documents` are specified, they are added to the `Answer` objects. If both `documents` and `reference_pattern` are specified, the documents referenced in the @@ -73,8 +74,11 @@ class AnswerBuilder: If not specified, no parsing is done, and all documents are referenced. Default: `None`. """ - if len(replies) != len(metadata): + if not metadata: + metadata = [{}] * len(replies) + elif len(replies) != len(metadata): raise ValueError(f"Number of replies ({len(replies)}), and metadata ({len(metadata)}) must match.") + if pattern: AnswerBuilder._check_num_groups_in_regex(pattern) diff --git a/test/preview/components/builders/test_answer_builder.py b/test/preview/components/builders/test_answer_builder.py index 03b4a42f9..e93ae2f03 100644 --- a/test/preview/components/builders/test_answer_builder.py +++ b/test/preview/components/builders/test_answer_builder.py @@ -36,7 +36,29 @@ class TestAnswerBuilder: def test_run_unmatching_input_len(self): component = AnswerBuilder() with pytest.raises(ValueError): - component.run(query="query", replies=["reply1", "reply2"], metadata=[]) + component.run(query="query", replies=["reply1"], metadata=[{"test": "meta"}, {"test": "meta2"}]) + + @pytest.mark.unit + def test_run_without_meta(self): + component = AnswerBuilder() + output = component.run(query="query", replies=["reply1"]) + answers = output["answers"] + assert answers[0].data == "reply1" + assert answers[0].metadata == {} + assert answers[0].query == "query" + assert answers[0].documents == [] + assert isinstance(answers[0], GeneratedAnswer) + + @pytest.mark.unit + def test_run_meta_is_an_empty_list(self): + component = AnswerBuilder() + output = component.run(query="query", replies=["reply1"], metadata=[]) + answers = output["answers"] + assert answers[0].data == "reply1" + assert answers[0].metadata == {} + assert answers[0].query == "query" + assert answers[0].documents == [] + assert isinstance(answers[0], GeneratedAnswer) def test_run_without_pattern(self): component = AnswerBuilder()