chore: Rename AnswersBuilder to AnswerBuilder (#5720)

* Add AnswersBuilder

* Add tests for AnswersBuilder

* Add release note

* PR feedback

* Fix mypy

* Remove redundant check for number of groups

* Rename AnswersBuilder to AnswerBuilder

* Update test/preview/components/builders/test_answer_builder.py

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>

* Rename reno file

---------

Co-authored-by: Daria Fokina <daria.fokina@deepset.ai>
This commit is contained in:
bogdankostic 2023-09-05 14:34:22 +02:00 committed by GitHub
parent 2acc41ea85
commit 639f7cf888
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 28 additions and 28 deletions

View File

@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
@component
class AnswersBuilder:
class AnswerBuilder:
"""
A component to parse the output of a Generator to `Answer` objects using regular expressions.
"""
@ -32,7 +32,7 @@ class AnswersBuilder:
Default: `None`.
"""
if pattern:
AnswersBuilder._check_num_groups_in_regex(pattern)
AnswerBuilder._check_num_groups_in_regex(pattern)
self.pattern = pattern
self.reference_pattern = reference_pattern
@ -80,7 +80,7 @@ class AnswersBuilder:
)
if pattern:
AnswersBuilder._check_num_groups_in_regex(pattern)
AnswerBuilder._check_num_groups_in_regex(pattern)
documents = documents or []
pattern = pattern or self.pattern
@ -90,10 +90,10 @@ class AnswersBuilder:
for i, (query, reply_list, meta_list) in enumerate(zip(queries, replies, metadata)):
doc_list = documents[i] if i < len(documents) else []
extracted_answer_strings = AnswersBuilder._extract_answer_strings(reply_list, pattern)
extracted_answer_strings = AnswerBuilder._extract_answer_strings(reply_list, pattern)
if doc_list and reference_pattern:
reference_idxs = AnswersBuilder._extract_reference_idxs(reply_list, reference_pattern)
reference_idxs = AnswerBuilder._extract_reference_idxs(reply_list, reference_pattern)
else:
reference_idxs = [[doc_idx for doc_idx, _ in enumerate(doc_list)] for _ in reply_list]
@ -120,7 +120,7 @@ class AnswersBuilder:
return default_to_dict(self, pattern=self.pattern, reference_pattern=self.reference_pattern)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "AnswersBuilder":
def from_dict(cls, data: Dict[str, Any]) -> "AnswerBuilder":
"""
Deserialize this component from a dictionary.
"""

View File

@ -0,0 +1,4 @@
---
preview:
- |
Add the `AnswerBuilder` component for Haystack 2.0 that creates Answer objects from the string output of Generators.

View File

@ -1,4 +0,0 @@
---
preview:
- |
Add the `AnswersBuilder` component for Haystack 2.0 that creates Answer objects from the string output of Generators.

View File

@ -3,43 +3,43 @@ import logging
import pytest
from haystack.preview import GeneratedAnswer, Document
from haystack.preview.components.builders.answers_builder import AnswersBuilder
from haystack.preview.components.builders.answer_builder import AnswerBuilder
class TestAnswersBuilder:
class TestAnswerBuilder:
@pytest.mark.unit
def test_to_dict(self):
component = AnswersBuilder()
component = AnswerBuilder()
data = component.to_dict()
assert data == {"type": "AnswersBuilder", "init_parameters": {"pattern": None, "reference_pattern": None}}
assert data == {"type": "AnswerBuilder", "init_parameters": {"pattern": None, "reference_pattern": None}}
@pytest.mark.unit
def test_to_dict_with_custom_init_parameters(self):
component = AnswersBuilder(pattern="pattern", reference_pattern="reference_pattern")
component = AnswerBuilder(pattern="pattern", reference_pattern="reference_pattern")
data = component.to_dict()
assert data == {
"type": "AnswersBuilder",
"type": "AnswerBuilder",
"init_parameters": {"pattern": "pattern", "reference_pattern": "reference_pattern"},
}
@pytest.mark.unit
def test_from_dict(self):
data = {
"type": "AnswersBuilder",
"type": "AnswerBuilder",
"init_parameters": {"pattern": "pattern", "reference_pattern": "reference_pattern"},
}
component = AnswersBuilder.from_dict(data)
component = AnswerBuilder.from_dict(data)
assert component.pattern == "pattern"
assert component.reference_pattern == "reference_pattern"
@pytest.mark.unit
def test_run_unmatching_input_len(self):
component = AnswersBuilder()
component = AnswerBuilder()
with pytest.raises(ValueError):
component.run(queries=["query"], replies=[["reply1"], ["reply2"]], metadata=[[]])
def test_run_without_pattern(self):
component = AnswersBuilder()
component = AnswerBuilder()
answers = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]])
assert len(answers) == 1
assert len(answers[0]) == 1
@ -50,7 +50,7 @@ class TestAnswersBuilder:
assert isinstance(answers[0][0], GeneratedAnswer)
def test_run_with_pattern_with_capturing_group(self):
component = AnswersBuilder(pattern=r"Answer: (.*)")
component = AnswerBuilder(pattern=r"Answer: (.*)")
answers = component.run(queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]])
assert len(answers) == 1
assert len(answers[0]) == 1
@ -61,7 +61,7 @@ class TestAnswersBuilder:
assert isinstance(answers[0][0], GeneratedAnswer)
def test_run_with_pattern_without_capturing_group(self):
component = AnswersBuilder(pattern=r"'.*'")
component = AnswerBuilder(pattern=r"'.*'")
answers = component.run(queries=["test query"], replies=[["Answer: 'AnswerString'"]], metadata=[[{}]])
assert len(answers) == 1
assert len(answers[0]) == 1
@ -73,10 +73,10 @@ class TestAnswersBuilder:
def test_run_with_pattern_with_more_than_one_capturing_group(self):
with pytest.raises(ValueError, match="contains multiple capture groups"):
component = AnswersBuilder(pattern=r"Answer: (.*), (.*)")
component = AnswerBuilder(pattern=r"Answer: (.*), (.*)")
def test_run_with_pattern_set_at_runtime(self):
component = AnswersBuilder(pattern="unused pattern")
component = AnswerBuilder(pattern="unused pattern")
answers = component.run(
queries=["test query"], replies=[["Answer: AnswerString"]], metadata=[[{}]], pattern=r"Answer: (.*)"
)
@ -89,7 +89,7 @@ class TestAnswersBuilder:
assert isinstance(answers[0][0], GeneratedAnswer)
def test_run_with_documents_without_reference_pattern(self):
component = AnswersBuilder()
component = AnswerBuilder()
answers = component.run(
queries=["test query"],
replies=[["Answer: AnswerString"]],
@ -106,7 +106,7 @@ class TestAnswersBuilder:
assert answers[0][0].documents[1].content == "test doc 2"
def test_run_with_documents_with_reference_pattern(self):
component = AnswersBuilder(reference_pattern="\\[(\\d+)\\]")
component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]")
answers = component.run(
queries=["test query"],
replies=[["Answer: AnswerString[2]"]],
@ -122,7 +122,7 @@ class TestAnswersBuilder:
assert answers[0][0].documents[0].content == "test doc 2"
def test_run_with_documents_with_reference_pattern_and_no_match(self, caplog):
component = AnswersBuilder(reference_pattern="\\[(\\d+)\\]")
component = AnswerBuilder(reference_pattern="\\[(\\d+)\\]")
with caplog.at_level(logging.WARNING):
answers = component.run(
queries=["test query"],
@ -139,7 +139,7 @@ class TestAnswersBuilder:
assert "Document index '3' referenced in Generator output is out of range." in caplog.text
def test_run_with_reference_pattern_set_at_runtime(self):
component = AnswersBuilder(reference_pattern="unused pattern")
component = AnswerBuilder(reference_pattern="unused pattern")
answers = component.run(
queries=["test query"],
replies=[["Answer: AnswerString[2][3]"]],