diff --git a/e2e/pipelines/test_extractive_qa_pipeline.py b/e2e/pipelines/test_extractive_qa_pipeline.py index 298859d12..9a5e6f675 100644 --- a/e2e/pipelines/test_extractive_qa_pipeline.py +++ b/e2e/pipelines/test_extractive_qa_pipeline.py @@ -51,15 +51,14 @@ def test_extractive_qa_pipeline(tmp_path): # no_answer assert extracted_answers[-1].data is None - # since these questions are easily answerable, the best answer should have higher probability than no_answer - assert extracted_answers[0].probability >= extracted_answers[-1].probability + # since these questions are easily answerable, the best answer should have higher score than no_answer + assert extracted_answers[0].score >= extracted_answers[-1].score for answer in extracted_answers: assert answer.query == question - assert hasattr(answer, "probability") - assert hasattr(answer, "start") - assert hasattr(answer, "end") + assert hasattr(answer, "score") + assert hasattr(answer, "document_offset") assert hasattr(answer, "document") # the answer is extracted from the correct document diff --git a/e2e/pipelines/test_rag_pipelines.py b/e2e/pipelines/test_rag_pipelines.py index 9ab11af96..556e3ce02 100644 --- a/e2e/pipelines/test_rag_pipelines.py +++ b/e2e/pipelines/test_rag_pipelines.py @@ -75,7 +75,7 @@ def test_bm25_rag_pipeline(tmp_path): assert spyword in generated_answer.data assert generated_answer.query == question assert hasattr(generated_answer, "documents") - assert hasattr(generated_answer, "metadata") + assert hasattr(generated_answer, "meta") @pytest.mark.skipif( @@ -156,4 +156,4 @@ def test_embedding_retrieval_rag_pipeline(tmp_path): assert spyword in generated_answer.data assert generated_answer.query == question assert hasattr(generated_answer, "documents") - assert hasattr(generated_answer, "metadata") + assert hasattr(generated_answer, "meta") diff --git a/haystack/components/builders/answer_builder.py b/haystack/components/builders/answer_builder.py index 1da053511..11735e269 100644 --- a/haystack/components/builders/answer_builder.py +++ b/haystack/components/builders/answer_builder.py @@ -102,7 +102,7 @@ class AnswerBuilder: logger.warning("Document index '%s' referenced in Generator output is out of range. ", idx + 1) answer_string = AnswerBuilder._extract_answer_string(reply, pattern) - answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, metadata=meta) + answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, meta=meta) all_answers.append(answer) return {"answers": all_answers} diff --git a/haystack/components/readers/extractive.py b/haystack/components/readers/extractive.py index 86291c415..deb400050 100644 --- a/haystack/components/readers/extractive.py +++ b/haystack/components/readers/extractive.py @@ -273,44 +273,42 @@ class ExtractiveReader: logit introduced with SQuAD 2. Instead, it just computes the probability that the answer does not exist in the top k or top p. """ - flat_answers_without_queries = [] + answers_without_query = [] for document_id, start_candidates_, end_candidates_, probabilities_ in zip( document_ids, start, end, probabilities ): for start_, end_, probability in zip(start_candidates_, end_candidates_, probabilities_): doc = flattened_documents[document_id] - # doc.content cannot be None, because those documents are filtered when preprocessing. - # However, mypy doesn't know that. 
- flat_answers_without_queries.append( - { - "data": doc.content[start_:end_], # type: ignore - "document": doc, - "probability": probability.item(), - "start": start_, - "end": end_, - "metadata": {}, - } + answers_without_query.append( + ExtractedAnswer( + query="", # Can't be None but we'll add it later + data=doc.content[start_:end_], # type: ignore + document=doc, + score=probability.item(), + document_offset=ExtractedAnswer.Span(start_, end_), + meta={}, + ) ) i = 0 nested_answers = [] for query_id in range(query_ids[-1] + 1): current_answers = [] - while i < len(flat_answers_without_queries) and query_ids[i // answers_per_seq] == query_id: - answer = flat_answers_without_queries[i] - answer["query"] = queries[query_id] - current_answers.append(ExtractedAnswer(**answer)) + while i < len(answers_without_query) and query_ids[i // answers_per_seq] == query_id: + answer = answers_without_query[i] + answer.query = queries[query_id] + current_answers.append(answer) i += 1 - current_answers = sorted(current_answers, key=lambda answer: answer.probability, reverse=True) + current_answers = sorted(current_answers, key=lambda answer: answer.score, reverse=True) current_answers = current_answers[:top_k] if no_answer: - no_answer_probability = math.prod(1 - answer.probability for answer in current_answers) + no_answer_score = math.prod(1 - answer.score for answer in current_answers) answer_ = ExtractedAnswer( - data=None, query=queries[query_id], metadata={}, document=None, probability=no_answer_probability + data=None, query=queries[query_id], meta={}, document=None, score=no_answer_score ) current_answers.append(answer_) - current_answers = sorted(current_answers, key=lambda answer: answer.probability, reverse=True) + current_answers = sorted(current_answers, key=lambda answer: answer.score, reverse=True) if confidence_threshold is not None: - current_answers = [answer for answer in current_answers if answer.probability >= confidence_threshold] + current_answers = [answer for answer in current_answers if answer.score >= confidence_threshold] nested_answers.append(current_answers) return nested_answers diff --git a/haystack/dataclasses/answer.py b/haystack/dataclasses/answer.py index eaa30316c..1d197a74e 100644 --- a/haystack/dataclasses/answer.py +++ b/haystack/dataclasses/answer.py @@ -1,25 +1,139 @@ -from typing import Any, Dict, List, Optional -from dataclasses import dataclass +import io +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable +from dataclasses import dataclass, field, asdict + +from pandas import DataFrame, read_json + +from haystack.core.serialization import default_from_dict, default_to_dict from haystack.dataclasses.document import Document -@dataclass(frozen=True) -class Answer: +@runtime_checkable +@dataclass +class Answer(Protocol): data: Any query: str - metadata: Dict[str, Any] + meta: Dict[str, Any] + + def to_dict(self) -> Dict[str, Any]: + ... + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "Answer": + ... 
-@dataclass(frozen=True) -class ExtractedAnswer(Answer): - data: Optional[str] - document: Optional[Document] - probability: float - start: Optional[int] = None - end: Optional[int] = None +@dataclass +class ExtractedAnswer: + query: str + score: float + data: Optional[str] = None + document: Optional[Document] = None + context: Optional[str] = None + document_offset: Optional["Span"] = None + context_offset: Optional["Span"] = None + meta: Dict[str, Any] = field(default_factory=dict) + + @dataclass + class Span: + start: int + end: int + + def to_dict(self) -> Dict[str, Any]: + document = self.document.to_dict(flatten=False) if self.document is not None else None + document_offset = asdict(self.document_offset) if self.document_offset is not None else None + context_offset = asdict(self.context_offset) if self.context_offset is not None else None + return default_to_dict( + self, + data=self.data, + query=self.query, + document=document, + context=self.context, + score=self.score, + document_offset=document_offset, + context_offset=context_offset, + meta=self.meta, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExtractedAnswer": + init_params = data.get("init_parameters", {}) + if (doc := init_params.get("document")) is not None: + data["init_parameters"]["document"] = Document.from_dict(doc) + + if (offset := init_params.get("document_offset")) is not None: + data["init_parameters"]["document_offset"] = ExtractedAnswer.Span(**offset) + + if (offset := init_params.get("context_offset")) is not None: + data["init_parameters"]["context_offset"] = ExtractedAnswer.Span(**offset) + return default_from_dict(cls, data) -@dataclass(frozen=True) -class GeneratedAnswer(Answer): +@dataclass +class ExtractedTableAnswer: + query: str + score: float + data: Optional[str] = None + document: Optional[Document] = None + context: Optional[DataFrame] = None + document_cells: List["Cell"] = field(default_factory=list) + context_cells: List["Cell"] = field(default_factory=list) + meta: Dict[str, Any] = field(default_factory=dict) + + @dataclass + class Cell: + row: int + column: int + + def to_dict(self) -> Dict[str, Any]: + document = self.document.to_dict(flatten=False) if self.document is not None else None + context = self.context.to_json() if self.context is not None else None + document_cells = [asdict(c) for c in self.document_cells] + context_cells = [asdict(c) for c in self.context_cells] + return default_to_dict( + self, + data=self.data, + query=self.query, + document=document, + context=context, + score=self.score, + document_cells=document_cells, + context_cells=context_cells, + meta=self.meta, + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ExtractedTableAnswer": + init_params = data.get("init_parameters", {}) + if (doc := init_params.get("document")) is not None: + data["init_parameters"]["document"] = Document.from_dict(doc) + + if (context := init_params.get("context")) is not None: + data["init_parameters"]["context"] = read_json(io.StringIO(context)) + + if (cells := init_params.get("document_cells")) is not None: + data["init_parameters"]["document_cells"] = [ExtractedTableAnswer.Cell(**c) for c in cells] + + if (cells := init_params.get("context_cells")) is not None: + data["init_parameters"]["context_cells"] = [ExtractedTableAnswer.Cell(**c) for c in cells] + return default_from_dict(cls, data) + + +@dataclass +class GeneratedAnswer: data: str + query: str documents: List[Document] + meta: Dict[str, Any] = field(default_factory=dict) + + def 
to_dict(self) -> Dict[str, Any]: + documents = [doc.to_dict(flatten=False) for doc in self.documents] + return default_to_dict(self, data=self.data, query=self.query, documents=documents, meta=self.meta) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "GeneratedAnswer": + init_params = data.get("init_parameters", {}) + if (documents := init_params.get("documents")) is not None: + data["init_parameters"]["documents"] = [Document.from_dict(d) for d in documents] + + return default_from_dict(cls, data) diff --git a/releasenotes/notes/answer-refactoring-b617afa946311ac8.yaml b/releasenotes/notes/answer-refactoring-b617afa946311ac8.yaml new file mode 100644 index 000000000..1c2f60980 --- /dev/null +++ b/releasenotes/notes/answer-refactoring-b617afa946311ac8.yaml @@ -0,0 +1,8 @@ +--- +enhancements: + - | + Refactor the `Answer` dataclass and the classes that inherited from it. + `Answer` is now a Protocol; the classes that previously inherited from it now implement that interface. + We also added a new `ExtractedTableAnswer` to be used for table question answering. + + All classes are now easily serializable using `to_dict()` and `from_dict()`, like `Document` and components. diff --git a/test/components/builders/test_answer_builder.py b/test/components/builders/test_answer_builder.py index 645fb33d1..10ec43ba6 100644 --- a/test/components/builders/test_answer_builder.py +++ b/test/components/builders/test_answer_builder.py @@ -17,7 +17,7 @@ class TestAnswerBuilder: output = component.run(query="query", replies=["reply1"]) answers = output["answers"] assert answers[0].data == "reply1" - assert answers[0].metadata == {} + assert answers[0].meta == {} assert answers[0].query == "query" assert answers[0].documents == [] assert isinstance(answers[0], GeneratedAnswer) @@ -27,7 +27,7 @@ output = component.run(query="query", replies=["reply1"], metadata=[]) answers = output["answers"] assert answers[0].data == "reply1" - assert answers[0].metadata == {} + assert answers[0].meta == {} assert answers[0].query == "query" assert answers[0].documents == [] assert isinstance(answers[0], GeneratedAnswer) @@ -38,7 +38,7 @@ answers = output["answers"] assert len(answers) == 1 assert answers[0].data == "Answer: AnswerString" - assert answers[0].metadata == {} + assert answers[0].meta == {} assert answers[0].query == "test query" assert answers[0].documents == [] assert isinstance(answers[0], GeneratedAnswer) @@ -49,7 +49,7 @@ answers = output["answers"] assert len(answers) == 1 assert answers[0].data == "AnswerString" - assert answers[0].metadata == {} + assert answers[0].meta == {} assert answers[0].query == "test query" assert answers[0].documents == [] assert isinstance(answers[0], GeneratedAnswer) @@ -60,7 +60,7 @@ answers = output["answers"] assert len(answers) == 1 assert answers[0].data == "'AnswerString'" - assert answers[0].metadata == {} + assert answers[0].meta == {} assert answers[0].query == "test query" assert answers[0].documents == [] assert isinstance(answers[0], GeneratedAnswer) @@ -77,7 +77,7 @@ answers = output["answers"] assert len(answers) == 1 assert answers[0].data == "AnswerString" - assert answers[0].metadata == {} + assert answers[0].meta == {} assert answers[0].query == "test query" assert answers[0].documents == [] assert isinstance(answers[0], GeneratedAnswer) @@ -93,7 +93,7 @@ answers = output["answers"] assert len(answers) == 1 assert
answers[0].data == "Answer: AnswerString" - assert answers[0].metadata == {} + assert answers[0].meta == {} assert answers[0].query == "test query" assert len(answers[0].documents) == 2 assert answers[0].documents[0].content == "test doc 1" @@ -110,7 +110,7 @@ class TestAnswerBuilder: answers = output["answers"] assert len(answers) == 1 assert answers[0].data == "Answer: AnswerString[2]" - assert answers[0].metadata == {} + assert answers[0].meta == {} assert answers[0].query == "test query" assert len(answers[0].documents) == 1 assert answers[0].documents[0].content == "test doc 2" @@ -127,7 +127,7 @@ class TestAnswerBuilder: answers = output["answers"] assert len(answers) == 1 assert answers[0].data == "Answer: AnswerString[3]" - assert answers[0].metadata == {} + assert answers[0].meta == {} assert answers[0].query == "test query" assert len(answers[0].documents) == 0 assert "Document index '3' referenced in Generator output is out of range." in caplog.text @@ -144,7 +144,7 @@ class TestAnswerBuilder: answers = output["answers"] assert len(answers) == 1 assert answers[0].data == "Answer: AnswerString[2][3]" - assert answers[0].metadata == {} + assert answers[0].meta == {} assert answers[0].query == "test query" assert len(answers[0].documents) == 2 assert answers[0].documents[0].content == "test doc 2" diff --git a/test/components/readers/test_extractive.py b/test/components/readers/test_extractive.py index a48c984c4..2ced6678e 100644 --- a/test/components/readers/test_extractive.py +++ b/test/components/readers/test_extractive.py @@ -140,15 +140,15 @@ def test_output(mock_reader: ExtractiveReader): doc_ids = set() no_answer_prob = 1 for doc, answer in zip(example_documents[0], answers[:3]): - assert answer.start == 11 - assert answer.end == 16 + assert answer.document_offset.start == 11 + assert answer.document_offset.end == 16 assert doc.content is not None assert answer.data == doc.content[11:16] - assert answer.probability == pytest.approx(1 / (1 + exp(-2 * mock_reader.calibration_factor))) - no_answer_prob *= 1 - answer.probability + assert answer.score == pytest.approx(1 / (1 + exp(-2 * mock_reader.calibration_factor))) + no_answer_prob *= 1 - answer.score doc_ids.add(doc.id) assert len(doc_ids) == 3 - assert answers[-1].probability == pytest.approx(no_answer_prob) + assert answers[-1].score == pytest.approx(no_answer_prob) def test_flatten_documents(mock_reader: ExtractiveReader): @@ -241,14 +241,14 @@ def test_nest_answers(mock_reader: ExtractiveReader): example_queries, nested_answers, expected_no_answers, [probabilities[:3, -1], probabilities[3:, -1]] ): assert len(answers) == 4 - for doc, answer, probability in zip(example_documents[0], reversed(answers[:3]), probabilities): + for doc, answer, score in zip(example_documents[0], reversed(answers[:3]), probabilities): assert answer.query == query assert answer.document == doc - assert answer.probability == pytest.approx(probability) + assert answer.score == pytest.approx(score) no_answer = answers[-1] assert no_answer.query == query assert no_answer.document is None - assert no_answer.probability == pytest.approx(expected_no_answer) + assert no_answer.score == pytest.approx(expected_no_answer) @patch("haystack.components.readers.extractive.AutoTokenizer.from_pretrained") @@ -269,19 +269,19 @@ def test_t5(): "answers" ] # remove indices when batching support is reintroduced assert answers[0].data == "Angela Merkel" - assert answers[0].probability == pytest.approx(0.7764519453048706) + assert answers[0].score == 
pytest.approx(0.7764519453048706) assert answers[1].data == "Olaf Scholz" - assert answers[1].probability == pytest.approx(0.7703777551651001) + assert answers[1].score == pytest.approx(0.7703777551651001) assert answers[2].data is None - assert answers[2].probability == pytest.approx(0.051331606147570596) + assert answers[2].score == pytest.approx(0.051331606147570596) # Uncomment assertions below when batching is reintroduced - # assert answers[0][2].probability == pytest.approx(0.051331606147570596) + # assert answers[0][2].score == pytest.approx(0.051331606147570596) # assert answers[1][0].data == "Jerry" - # assert answers[1][0].probability == pytest.approx(0.7413333654403687) + # assert answers[1][0].score == pytest.approx(0.7413333654403687) # assert answers[1][1].data == "Olaf Scholz" - # assert answers[1][1].probability == pytest.approx(0.7266613841056824) + # assert answers[1][1].score == pytest.approx(0.7266613841056824) # assert answers[1][2].data is None - # assert answers[1][2].probability == pytest.approx(0.0707035798685709) + # assert answers[1][2].score == pytest.approx(0.0707035798685709) @pytest.mark.integration @@ -292,24 +292,24 @@ def test_roberta(): "answers" ] # remove indices when batching is reintroduced assert answers[0].data == "Olaf Scholz" - assert answers[0].probability == pytest.approx(0.8614975214004517) + assert answers[0].score == pytest.approx(0.8614975214004517) assert answers[1].data == "Angela Merkel" - assert answers[1].probability == pytest.approx(0.857952892780304) + assert answers[1].score == pytest.approx(0.857952892780304) assert answers[2].data is None - assert answers[2].probability == pytest.approx(0.019673851661650588) + assert answers[2].score == pytest.approx(0.019673851661650588) # uncomment assertions below when there is batching in v2 # assert answers[0][0].data == "Olaf Scholz" - # assert answers[0][0].probability == pytest.approx(0.8614975214004517) + # assert answers[0][0].score == pytest.approx(0.8614975214004517) # assert answers[0][1].data == "Angela Merkel" - # assert answers[0][1].probability == pytest.approx(0.857952892780304) + # assert answers[0][1].score == pytest.approx(0.857952892780304) # assert answers[0][2].data is None - # assert answers[0][2].probability == pytest.approx(0.0196738764278237) + # assert answers[0][2].score == pytest.approx(0.0196738764278237) # assert answers[1][0].data == "Jerry" - # assert answers[1][0].probability == pytest.approx(0.7048940658569336) + # assert answers[1][0].score == pytest.approx(0.7048940658569336) # assert answers[1][1].data == "Olaf Scholz" - # assert answers[1][1].probability == pytest.approx(0.6604189872741699) + # assert answers[1][1].score == pytest.approx(0.6604189872741699) # assert answers[1][2].data is None - # assert answers[1][2].probability == pytest.approx(0.1002123719777046) + # assert answers[1][2].score == pytest.approx(0.1002123719777046) @pytest.mark.integration @@ -329,6 +329,6 @@ def test_matches_hf_pipeline(): ) # We need to disable HF postprocessing features to make the results comparable. 
This is related to https://github.com/huggingface/transformers/issues/26286 assert len(answers) == len(answers_hf) == 20 for answer, answer_hf in zip(answers, answers_hf): - assert answer.start == answer_hf["start"] - assert answer.end == answer_hf["end"] + assert answer.document_offset.start == answer_hf["start"] + assert answer.document_offset.end == answer_hf["end"] assert answer.data == answer_hf["answer"] diff --git a/test/dataclasses/test_answer.py b/test/dataclasses/test_answer.py new file mode 100644 index 000000000..ee2e64054 --- /dev/null +++ b/test/dataclasses/test_answer.py @@ -0,0 +1,268 @@ +from pandas import DataFrame + +from haystack.dataclasses.answer import Answer, ExtractedAnswer, ExtractedTableAnswer, GeneratedAnswer +from haystack.dataclasses.document import Document + + +class TestExtractedAnswer: + def test_init(self): + answer = ExtractedAnswer( + data="42", + query="What is the answer?", + document=Document(content="I thought a lot about this. The answer is 42."), + context="The answer is 42.", + score=1.0, + document_offset=ExtractedAnswer.Span(42, 44), + context_offset=ExtractedAnswer.Span(14, 16), + meta={"meta_key": "meta_value"}, + ) + assert answer.data == "42" + assert answer.query == "What is the answer?" + assert answer.document == Document(content="I thought a lot about this. The answer is 42.") + assert answer.context == "The answer is 42." + assert answer.score == 1.0 + assert answer.document_offset == ExtractedAnswer.Span(42, 44) + assert answer.context_offset == ExtractedAnswer.Span(14, 16) + assert answer.meta == {"meta_key": "meta_value"} + + def test_protocol(self): + answer = ExtractedAnswer( + data="42", + query="What is the answer?", + document=Document(content="I thought a lot about this. The answer is 42."), + context="The answer is 42.", + score=1.0, + document_offset=ExtractedAnswer.Span(42, 44), + context_offset=ExtractedAnswer.Span(14, 16), + meta={"meta_key": "meta_value"}, + ) + assert isinstance(answer, Answer) + + def test_to_dict(self): + document = Document(content="I thought a lot about this. The answer is 42.") + answer = ExtractedAnswer( + data="42", + query="What is the answer?", + document=document, + context="The answer is 42.", + score=1.0, + document_offset=ExtractedAnswer.Span(42, 44), + context_offset=ExtractedAnswer.Span(14, 16), + meta={"meta_key": "meta_value"}, + ) + assert answer.to_dict() == { + "type": "haystack.dataclasses.answer.ExtractedAnswer", + "init_parameters": { + "data": "42", + "query": "What is the answer?", + "document": document.to_dict(flatten=False), + "context": "The answer is 42.", + "score": 1.0, + "document_offset": {"start": 42, "end": 44}, + "context_offset": {"start": 14, "end": 16}, + "meta": {"meta_key": "meta_value"}, + }, + } + + def test_from_dict(self): + answer = ExtractedAnswer.from_dict( + { + "type": "haystack.dataclasses.answer.ExtractedAnswer", + "init_parameters": { + "data": "42", + "query": "What is the answer?", + "document": { + "id": "8f800a524b139484fc719ecc35f971a080de87618319bc4836b784d69baca57f", + "content": "I thought a lot about this. The answer is 42.", + }, + "context": "The answer is 42.", + "score": 1.0, + "document_offset": {"start": 42, "end": 44}, + "context_offset": {"start": 14, "end": 16}, + "meta": {"meta_key": "meta_value"}, + }, + } + ) + assert answer.data == "42" + assert answer.query == "What is the answer?" 
+ assert answer.document == Document( + id="8f800a524b139484fc719ecc35f971a080de87618319bc4836b784d69baca57f", + content="I thought a lot about this. The answer is 42.", + ) + assert answer.context == "The answer is 42." + assert answer.score == 1.0 + assert answer.document_offset == ExtractedAnswer.Span(42, 44) + assert answer.context_offset == ExtractedAnswer.Span(14, 16) + assert answer.meta == {"meta_key": "meta_value"} + + +class TestExtractedTableAnswer: + def test_init(self): + answer = ExtractedTableAnswer( + data="42", + query="What is the answer?", + document=Document(dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 42]})), + context=DataFrame({"col3": [5, 42]}), + score=1.0, + document_cells=[ExtractedTableAnswer.Cell(1, 2)], + context_cells=[ExtractedTableAnswer.Cell(1, 0)], + meta={"meta_key": "meta_value"}, + ) + assert answer.data == "42" + assert answer.query == "What is the answer?" + assert answer.document == Document(dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 42]})) + assert answer.context.equals(DataFrame({"col3": [5, 42]})) + assert answer.score == 1.0 + assert answer.document_cells == [ExtractedTableAnswer.Cell(1, 2)] + assert answer.context_cells == [ExtractedTableAnswer.Cell(1, 0)] + assert answer.meta == {"meta_key": "meta_value"} + + def test_protocol(self): + answer = ExtractedTableAnswer( + data="42", + query="What is the answer?", + document=Document(dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 42]})), + context=DataFrame({"col3": [5, 42]}), + score=1.0, + document_cells=[ExtractedTableAnswer.Cell(1, 2)], + context_cells=[ExtractedTableAnswer.Cell(1, 0)], + meta={"meta_key": "meta_value"}, + ) + assert isinstance(answer, Answer) + + def test_to_dict(self): + document = Document(dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 42]})) + answer = ExtractedTableAnswer( + data="42", + query="What is the answer?", + document=document, + context=DataFrame({"col3": [5, 42]}), + score=1.0, + document_cells=[ExtractedTableAnswer.Cell(1, 2)], + context_cells=[ExtractedTableAnswer.Cell(1, 0)], + meta={"meta_key": "meta_value"}, + ) + + assert answer.to_dict() == { + "type": "haystack.dataclasses.answer.ExtractedTableAnswer", + "init_parameters": { + "data": "42", + "query": "What is the answer?", + "document": document.to_dict(flatten=False), + "context": DataFrame({"col3": [5, 42]}).to_json(), + "score": 1.0, + "document_cells": [{"row": 1, "column": 2}], + "context_cells": [{"row": 1, "column": 0}], + "meta": {"meta_key": "meta_value"}, + }, + } + + def test_from_dict(self): + answer = ExtractedTableAnswer.from_dict( + { + "type": "haystack.dataclasses.answer.ExtractedTableAnswer", + "init_parameters": { + "data": "42", + "query": "What is the answer?", + "document": { + "id": "3b13a0d56a3697e27a874fcb621911c83c59388dec213909e9e40d5d9f0affed", + "dataframe": '{"col1":{"0":1,"1":2},"col2":{"0":3,"1":4},"col3":{"0":5,"1":42}}', + }, + "context": '{"col3":{"0":5,"1":42}}', + "score": 1.0, + "document_cells": [{"row": 1, "column": 2}], + "context_cells": [{"row": 1, "column": 0}], + "meta": {"meta_key": "meta_value"}, + }, + } + ) + assert answer.data == "42" + assert answer.query == "What is the answer?" 
+ assert answer.document == Document( + id="3b13a0d56a3697e27a874fcb621911c83c59388dec213909e9e40d5d9f0affed", + dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 42]}), + ) + assert answer.context.equals(DataFrame({"col3": [5, 42]})) + assert answer.score == 1.0 + assert answer.document_cells == [ExtractedTableAnswer.Cell(1, 2)] + assert answer.context_cells == [ExtractedTableAnswer.Cell(1, 0)] + assert answer.meta == {"meta_key": "meta_value"} + + +class TestGeneratedAnswer: + def test_init(self): + answer = GeneratedAnswer( + data="42", + query="What is the answer?", + documents=[ + Document(id="1", content="The answer is 42."), + Document(id="2", content="I believe the answer is 42."), + Document(id="3", content="42 is definitely the answer."), + ], + meta={"meta_key": "meta_value"}, + ) + assert answer.data == "42" + assert answer.query == "What is the answer?" + assert answer.documents == [ + Document(id="1", content="The answer is 42."), + Document(id="2", content="I believe the answer is 42."), + Document(id="3", content="42 is definitely the answer."), + ] + assert answer.meta == {"meta_key": "meta_value"} + + def test_protocol(self): + answer = GeneratedAnswer( + data="42", + query="What is the answer?", + documents=[ + Document(id="1", content="The answer is 42."), + Document(id="2", content="I believe the answer is 42."), + Document(id="3", content="42 is definitely the answer."), + ], + meta={"meta_key": "meta_value"}, + ) + assert isinstance(answer, Answer) + + def test_to_dict(self): + documents = [ + Document(id="1", content="The answer is 42."), + Document(id="2", content="I believe the answer is 42."), + Document(id="3", content="42 is definitely the answer."), + ] + answer = GeneratedAnswer( + data="42", query="What is the answer?", documents=documents, meta={"meta_key": "meta_value"} + ) + assert answer.to_dict() == { + "type": "haystack.dataclasses.answer.GeneratedAnswer", + "init_parameters": { + "data": "42", + "query": "What is the answer?", + "documents": [d.to_dict(flatten=False) for d in documents], + "meta": {"meta_key": "meta_value"}, + }, + } + + def test_from_dict(self): + answer = GeneratedAnswer.from_dict( + { + "type": "haystack.dataclasses.answer.GeneratedAnswer", + "init_parameters": { + "data": "42", + "query": "What is the answer?", + "documents": [ + {"id": "1", "content": "The answer is 42."}, + {"id": "2", "content": "I believe the answer is 42."}, + {"id": "3", "content": "42 is definitely the answer."}, + ], + "meta": {"meta_key": "meta_value"}, + }, + } + ) + assert answer.data == "42" + assert answer.query == "What is the answer?" + assert answer.documents == [ + Document(id="1", content="The answer is 42."), + Document(id="2", content="I believe the answer is 42."), + Document(id="3", content="42 is definitely the answer."), + ] + assert answer.meta == {"meta_key": "meta_value"} diff --git a/test/dataclasses/test_document.py b/test/dataclasses/test_document.py index 4ff75a729..72e45a4b6 100644 --- a/test/dataclasses/test_document.py +++ b/test/dataclasses/test_document.py @@ -1,5 +1,3 @@ -from pathlib import Path - import pandas as pd import pytest
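
For reference, here is a minimal usage sketch of the refactored answer dataclasses exercised by the tests above. The import paths, constructor arguments, and to_dict()/from_dict() behaviour are taken from this diff; the concrete score, offsets, and meta values are illustrative only, and the snippet is not part of the patch itself.

from haystack.dataclasses.answer import Answer, ExtractedAnswer, GeneratedAnswer
from haystack.dataclasses.document import Document

# Build an extractive answer; Span holds half-open [start, end) character offsets
# into the source document's content, so content[42:44] == "42" here.
doc = Document(content="I thought a lot about this. The answer is 42.")
extracted = ExtractedAnswer(
    query="What is the answer?",
    score=0.97,  # illustrative value, not taken from the patch
    data="42",
    document=doc,
    document_offset=ExtractedAnswer.Span(42, 44),
    meta={},
)

# The concrete classes satisfy the runtime-checkable Answer protocol.
assert isinstance(extracted, Answer)

# Serialization mirrors Document and components: a dict with "type" and "init_parameters".
as_dict = extracted.to_dict()
restored = ExtractedAnswer.from_dict(as_dict)
assert restored.data == "42"
assert restored.document_offset == ExtractedAnswer.Span(42, 44)

# The reader's "no answer" entry keeps data and document as None; its score is the
# product of (1 - score) over the extracted candidates, as computed in ExtractiveReader above.
no_answer = ExtractedAnswer(query="What is the answer?", score=1 - extracted.score, data=None, document=None, meta={})

# GeneratedAnswer uses the renamed meta field (formerly metadata) and round-trips the same way.
generated = GeneratedAnswer(data="42", query="What is the answer?", documents=[doc], meta={})
assert GeneratedAnswer.from_dict(generated.to_dict()).data == "42"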