refactor: Refactor answer dataclasses (#6523)

* Refactor answer dataclasses

* Add release notes

* Fix tests

* Fix end to end tests

* Enhance ExtractiveReader
This commit is contained in:
Silvano Cerza 2023-12-11 18:50:49 +01:00 committed by GitHub
parent 820d9c37d5
commit 18dbce25fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 466 additions and 81 deletions

View File

@ -51,15 +51,14 @@ def test_extractive_qa_pipeline(tmp_path):
# no_answer
assert extracted_answers[-1].data is None
# since these questions are easily answerable, the best answer should have higher probability than no_answer
assert extracted_answers[0].probability >= extracted_answers[-1].probability
# since these questions are easily answerable, the best answer should have higher score than no_answer
assert extracted_answers[0].score >= extracted_answers[-1].score
for answer in extracted_answers:
assert answer.query == question
assert hasattr(answer, "probability")
assert hasattr(answer, "start")
assert hasattr(answer, "end")
assert hasattr(answer, "score")
assert hasattr(answer, "document_offset")
assert hasattr(answer, "document")
# the answer is extracted from the correct document

View File

@ -75,7 +75,7 @@ def test_bm25_rag_pipeline(tmp_path):
assert spyword in generated_answer.data
assert generated_answer.query == question
assert hasattr(generated_answer, "documents")
assert hasattr(generated_answer, "metadata")
assert hasattr(generated_answer, "meta")
@pytest.mark.skipif(
@ -156,4 +156,4 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
assert spyword in generated_answer.data
assert generated_answer.query == question
assert hasattr(generated_answer, "documents")
assert hasattr(generated_answer, "metadata")
assert hasattr(generated_answer, "meta")

View File

@ -102,7 +102,7 @@ class AnswerBuilder:
logger.warning("Document index '%s' referenced in Generator output is out of range. ", idx + 1)
answer_string = AnswerBuilder._extract_answer_string(reply, pattern)
answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, metadata=meta)
answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, meta=meta)
all_answers.append(answer)
return {"answers": all_answers}

View File

@ -273,44 +273,42 @@ class ExtractiveReader:
logit introduced with SQuAD 2. Instead, it just computes the probability that the answer does not exist
in the top k or top p.
"""
flat_answers_without_queries = []
answers_without_query = []
for document_id, start_candidates_, end_candidates_, probabilities_ in zip(
document_ids, start, end, probabilities
):
for start_, end_, probability in zip(start_candidates_, end_candidates_, probabilities_):
doc = flattened_documents[document_id]
# doc.content cannot be None, because those documents are filtered when preprocessing.
# However, mypy doesn't know that.
flat_answers_without_queries.append(
{
"data": doc.content[start_:end_], # type: ignore
"document": doc,
"probability": probability.item(),
"start": start_,
"end": end_,
"metadata": {},
}
answers_without_query.append(
ExtractedAnswer(
query="", # Can't be None but we'll add it later
data=doc.content[start_:end_], # type: ignore
document=doc,
score=probability.item(),
document_offset=ExtractedAnswer.Span(start_, end_),
meta={},
)
)
i = 0
nested_answers = []
for query_id in range(query_ids[-1] + 1):
current_answers = []
while i < len(flat_answers_without_queries) and query_ids[i // answers_per_seq] == query_id:
answer = flat_answers_without_queries[i]
answer["query"] = queries[query_id]
current_answers.append(ExtractedAnswer(**answer))
while i < len(answers_without_query) and query_ids[i // answers_per_seq] == query_id:
answer = answers_without_query[i]
answer.query = queries[query_id]
current_answers.append(answer)
i += 1
current_answers = sorted(current_answers, key=lambda answer: answer.probability, reverse=True)
current_answers = sorted(current_answers, key=lambda answer: answer.score, reverse=True)
current_answers = current_answers[:top_k]
if no_answer:
no_answer_probability = math.prod(1 - answer.probability for answer in current_answers)
no_answer_score = math.prod(1 - answer.score for answer in current_answers)
answer_ = ExtractedAnswer(
data=None, query=queries[query_id], metadata={}, document=None, probability=no_answer_probability
data=None, query=queries[query_id], meta={}, document=None, score=no_answer_score
)
current_answers.append(answer_)
current_answers = sorted(current_answers, key=lambda answer: answer.probability, reverse=True)
current_answers = sorted(current_answers, key=lambda answer: answer.score, reverse=True)
if confidence_threshold is not None:
current_answers = [answer for answer in current_answers if answer.probability >= confidence_threshold]
current_answers = [answer for answer in current_answers if answer.score >= confidence_threshold]
nested_answers.append(current_answers)
return nested_answers

View File

@ -1,25 +1,139 @@
from typing import Any, Dict, List, Optional
from dataclasses import dataclass
import io
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from dataclasses import dataclass, field, asdict
from pandas import DataFrame, read_json
from haystack.core.serialization import default_from_dict, default_to_dict
from haystack.dataclasses.document import Document
@dataclass(frozen=True)
class Answer:
@runtime_checkable
@dataclass
class Answer(Protocol):
data: Any
query: str
metadata: Dict[str, Any]
meta: Dict[str, Any]
def to_dict(self) -> Dict[str, Any]:
...
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Answer":
...
@dataclass(frozen=True)
class ExtractedAnswer(Answer):
data: Optional[str]
document: Optional[Document]
probability: float
start: Optional[int] = None
end: Optional[int] = None
@dataclass
class ExtractedAnswer:
query: str
score: float
data: Optional[str] = None
document: Optional[Document] = None
context: Optional[str] = None
document_offset: Optional["Span"] = None
context_offset: Optional["Span"] = None
meta: Dict[str, Any] = field(default_factory=dict)
@dataclass
class Span:
start: int
end: int
def to_dict(self) -> Dict[str, Any]:
document = self.document.to_dict(flatten=False) if self.document is not None else None
document_offset = asdict(self.document_offset) if self.document_offset is not None else None
context_offset = asdict(self.context_offset) if self.context_offset is not None else None
return default_to_dict(
self,
data=self.data,
query=self.query,
document=document,
context=self.context,
score=self.score,
document_offset=document_offset,
context_offset=context_offset,
meta=self.meta,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ExtractedAnswer":
init_params = data.get("init_parameters", {})
if (doc := init_params.get("document")) is not None:
data["init_parameters"]["document"] = Document.from_dict(doc)
if (offset := init_params.get("document_offset")) is not None:
data["init_parameters"]["document_offset"] = ExtractedAnswer.Span(**offset)
if (offset := init_params.get("context_offset")) is not None:
data["init_parameters"]["context_offset"] = ExtractedAnswer.Span(**offset)
return default_from_dict(cls, data)
@dataclass(frozen=True)
class GeneratedAnswer(Answer):
@dataclass
class ExtractedTableAnswer:
query: str
score: float
data: Optional[str] = None
document: Optional[Document] = None
context: Optional[DataFrame] = None
document_cells: List["Cell"] = field(default_factory=list)
context_cells: List["Cell"] = field(default_factory=list)
meta: Dict[str, Any] = field(default_factory=dict)
@dataclass
class Cell:
row: int
column: int
def to_dict(self) -> Dict[str, Any]:
document = self.document.to_dict(flatten=False) if self.document is not None else None
context = self.context.to_json() if self.context is not None else None
document_cells = [asdict(c) for c in self.document_cells]
context_cells = [asdict(c) for c in self.context_cells]
return default_to_dict(
self,
data=self.data,
query=self.query,
document=document,
context=context,
score=self.score,
document_cells=document_cells,
context_cells=context_cells,
meta=self.meta,
)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ExtractedTableAnswer":
init_params = data.get("init_parameters", {})
if (doc := init_params.get("document")) is not None:
data["init_parameters"]["document"] = Document.from_dict(doc)
if (context := init_params.get("context")) is not None:
data["init_parameters"]["context"] = read_json(io.StringIO(context))
if (cells := init_params.get("document_cells")) is not None:
data["init_parameters"]["document_cells"] = [ExtractedTableAnswer.Cell(**c) for c in cells]
if (cells := init_params.get("context_cells")) is not None:
data["init_parameters"]["context_cells"] = [ExtractedTableAnswer.Cell(**c) for c in cells]
return default_from_dict(cls, data)
@dataclass
class GeneratedAnswer:
data: str
query: str
documents: List[Document]
meta: Dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]:
documents = [doc.to_dict(flatten=False) for doc in self.documents]
return default_to_dict(self, data=self.data, query=self.query, documents=documents, meta=self.meta)
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "GeneratedAnswer":
init_params = data.get("init_parameters", {})
if (documents := init_params.get("documents")) is not None:
data["init_parameters"]["documents"] = [Document.from_dict(d) for d in documents]
return default_from_dict(cls, data)

View File

@ -0,0 +1,8 @@
---
enhancements:
- |
Refactor the `Answer` dataclass and the classes that inherited from it.
`Answer` is now a Protocol; the classes that used to inherit from it now implement that interface.
We also added a new `ExtractedTableAnswer` class to be used for table question answering.
All classes are now easily serializable using `to_dict()` and `from_dict()`, like `Document` and components.

View File

@ -17,7 +17,7 @@ class TestAnswerBuilder:
output = component.run(query="query", replies=["reply1"])
answers = output["answers"]
assert answers[0].data == "reply1"
assert answers[0].metadata == {}
assert answers[0].meta == {}
assert answers[0].query == "query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)
@ -27,7 +27,7 @@ class TestAnswerBuilder:
output = component.run(query="query", replies=["reply1"], metadata=[])
answers = output["answers"]
assert answers[0].data == "reply1"
assert answers[0].metadata == {}
assert answers[0].meta == {}
assert answers[0].query == "query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)
@ -38,7 +38,7 @@ class TestAnswerBuilder:
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "Answer: AnswerString"
assert answers[0].metadata == {}
assert answers[0].meta == {}
assert answers[0].query == "test query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)
@ -49,7 +49,7 @@ class TestAnswerBuilder:
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "AnswerString"
assert answers[0].metadata == {}
assert answers[0].meta == {}
assert answers[0].query == "test query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)
@ -60,7 +60,7 @@ class TestAnswerBuilder:
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "'AnswerString'"
assert answers[0].metadata == {}
assert answers[0].meta == {}
assert answers[0].query == "test query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)
@ -77,7 +77,7 @@ class TestAnswerBuilder:
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "AnswerString"
assert answers[0].metadata == {}
assert answers[0].meta == {}
assert answers[0].query == "test query"
assert answers[0].documents == []
assert isinstance(answers[0], GeneratedAnswer)
@ -93,7 +93,7 @@ class TestAnswerBuilder:
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "Answer: AnswerString"
assert answers[0].metadata == {}
assert answers[0].meta == {}
assert answers[0].query == "test query"
assert len(answers[0].documents) == 2
assert answers[0].documents[0].content == "test doc 1"
@ -110,7 +110,7 @@ class TestAnswerBuilder:
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "Answer: AnswerString[2]"
assert answers[0].metadata == {}
assert answers[0].meta == {}
assert answers[0].query == "test query"
assert len(answers[0].documents) == 1
assert answers[0].documents[0].content == "test doc 2"
@ -127,7 +127,7 @@ class TestAnswerBuilder:
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "Answer: AnswerString[3]"
assert answers[0].metadata == {}
assert answers[0].meta == {}
assert answers[0].query == "test query"
assert len(answers[0].documents) == 0
assert "Document index '3' referenced in Generator output is out of range." in caplog.text
@ -144,7 +144,7 @@ class TestAnswerBuilder:
answers = output["answers"]
assert len(answers) == 1
assert answers[0].data == "Answer: AnswerString[2][3]"
assert answers[0].metadata == {}
assert answers[0].meta == {}
assert answers[0].query == "test query"
assert len(answers[0].documents) == 2
assert answers[0].documents[0].content == "test doc 2"

View File

@ -140,15 +140,15 @@ def test_output(mock_reader: ExtractiveReader):
doc_ids = set()
no_answer_prob = 1
for doc, answer in zip(example_documents[0], answers[:3]):
assert answer.start == 11
assert answer.end == 16
assert answer.document_offset.start == 11
assert answer.document_offset.end == 16
assert doc.content is not None
assert answer.data == doc.content[11:16]
assert answer.probability == pytest.approx(1 / (1 + exp(-2 * mock_reader.calibration_factor)))
no_answer_prob *= 1 - answer.probability
assert answer.score == pytest.approx(1 / (1 + exp(-2 * mock_reader.calibration_factor)))
no_answer_prob *= 1 - answer.score
doc_ids.add(doc.id)
assert len(doc_ids) == 3
assert answers[-1].probability == pytest.approx(no_answer_prob)
assert answers[-1].score == pytest.approx(no_answer_prob)
def test_flatten_documents(mock_reader: ExtractiveReader):
@ -241,14 +241,14 @@ def test_nest_answers(mock_reader: ExtractiveReader):
example_queries, nested_answers, expected_no_answers, [probabilities[:3, -1], probabilities[3:, -1]]
):
assert len(answers) == 4
for doc, answer, probability in zip(example_documents[0], reversed(answers[:3]), probabilities):
for doc, answer, score in zip(example_documents[0], reversed(answers[:3]), probabilities):
assert answer.query == query
assert answer.document == doc
assert answer.probability == pytest.approx(probability)
assert answer.score == pytest.approx(score)
no_answer = answers[-1]
assert no_answer.query == query
assert no_answer.document is None
assert no_answer.probability == pytest.approx(expected_no_answer)
assert no_answer.score == pytest.approx(expected_no_answer)
@patch("haystack.components.readers.extractive.AutoTokenizer.from_pretrained")
@ -269,19 +269,19 @@ def test_t5():
"answers"
] # remove indices when batching support is reintroduced
assert answers[0].data == "Angela Merkel"
assert answers[0].probability == pytest.approx(0.7764519453048706)
assert answers[0].score == pytest.approx(0.7764519453048706)
assert answers[1].data == "Olaf Scholz"
assert answers[1].probability == pytest.approx(0.7703777551651001)
assert answers[1].score == pytest.approx(0.7703777551651001)
assert answers[2].data is None
assert answers[2].probability == pytest.approx(0.051331606147570596)
assert answers[2].score == pytest.approx(0.051331606147570596)
# Uncomment assertions below when batching is reintroduced
# assert answers[0][2].probability == pytest.approx(0.051331606147570596)
# assert answers[0][2].score == pytest.approx(0.051331606147570596)
# assert answers[1][0].data == "Jerry"
# assert answers[1][0].probability == pytest.approx(0.7413333654403687)
# assert answers[1][0].score == pytest.approx(0.7413333654403687)
# assert answers[1][1].data == "Olaf Scholz"
# assert answers[1][1].probability == pytest.approx(0.7266613841056824)
# assert answers[1][1].score == pytest.approx(0.7266613841056824)
# assert answers[1][2].data is None
# assert answers[1][2].probability == pytest.approx(0.0707035798685709)
# assert answers[1][2].score == pytest.approx(0.0707035798685709)
@pytest.mark.integration
@ -292,24 +292,24 @@ def test_roberta():
"answers"
] # remove indices when batching is reintroduced
assert answers[0].data == "Olaf Scholz"
assert answers[0].probability == pytest.approx(0.8614975214004517)
assert answers[0].score == pytest.approx(0.8614975214004517)
assert answers[1].data == "Angela Merkel"
assert answers[1].probability == pytest.approx(0.857952892780304)
assert answers[1].score == pytest.approx(0.857952892780304)
assert answers[2].data is None
assert answers[2].probability == pytest.approx(0.019673851661650588)
assert answers[2].score == pytest.approx(0.019673851661650588)
# uncomment assertions below when there is batching in v2
# assert answers[0][0].data == "Olaf Scholz"
# assert answers[0][0].probability == pytest.approx(0.8614975214004517)
# assert answers[0][0].score == pytest.approx(0.8614975214004517)
# assert answers[0][1].data == "Angela Merkel"
# assert answers[0][1].probability == pytest.approx(0.857952892780304)
# assert answers[0][1].score == pytest.approx(0.857952892780304)
# assert answers[0][2].data is None
# assert answers[0][2].probability == pytest.approx(0.0196738764278237)
# assert answers[0][2].score == pytest.approx(0.0196738764278237)
# assert answers[1][0].data == "Jerry"
# assert answers[1][0].probability == pytest.approx(0.7048940658569336)
# assert answers[1][0].score == pytest.approx(0.7048940658569336)
# assert answers[1][1].data == "Olaf Scholz"
# assert answers[1][1].probability == pytest.approx(0.6604189872741699)
# assert answers[1][1].score == pytest.approx(0.6604189872741699)
# assert answers[1][2].data is None
# assert answers[1][2].probability == pytest.approx(0.1002123719777046)
# assert answers[1][2].score == pytest.approx(0.1002123719777046)
@pytest.mark.integration
@ -329,6 +329,6 @@ def test_matches_hf_pipeline():
) # We need to disable HF postprocessing features to make the results comparable. This is related to https://github.com/huggingface/transformers/issues/26286
assert len(answers) == len(answers_hf) == 20
for answer, answer_hf in zip(answers, answers_hf):
assert answer.start == answer_hf["start"]
assert answer.end == answer_hf["end"]
assert answer.document_offset.start == answer_hf["start"]
assert answer.document_offset.end == answer_hf["end"]
assert answer.data == answer_hf["answer"]

View File

@ -0,0 +1,268 @@
from pandas import DataFrame
from haystack.dataclasses.answer import Answer, ExtractedAnswer, ExtractedTableAnswer, GeneratedAnswer
from haystack.dataclasses.document import Document
class TestExtractedAnswer:
def test_init(self):
answer = ExtractedAnswer(
data="42",
query="What is the answer?",
document=Document(content="I thought a lot about this. The answer is 42."),
context="The answer is 42.",
score=1.0,
document_offset=ExtractedAnswer.Span(42, 44),
context_offset=ExtractedAnswer.Span(14, 16),
meta={"meta_key": "meta_value"},
)
assert answer.data == "42"
assert answer.query == "What is the answer?"
assert answer.document == Document(content="I thought a lot about this. The answer is 42.")
assert answer.context == "The answer is 42."
assert answer.score == 1.0
assert answer.document_offset == ExtractedAnswer.Span(42, 44)
assert answer.context_offset == ExtractedAnswer.Span(14, 16)
assert answer.meta == {"meta_key": "meta_value"}
def test_protocol(self):
answer = ExtractedAnswer(
data="42",
query="What is the answer?",
document=Document(content="I thought a lot about this. The answer is 42."),
context="The answer is 42.",
score=1.0,
document_offset=ExtractedAnswer.Span(42, 44),
context_offset=ExtractedAnswer.Span(14, 16),
meta={"meta_key": "meta_value"},
)
assert isinstance(answer, Answer)
def test_to_dict(self):
document = Document(content="I thought a lot about this. The answer is 42.")
answer = ExtractedAnswer(
data="42",
query="What is the answer?",
document=document,
context="The answer is 42.",
score=1.0,
document_offset=ExtractedAnswer.Span(42, 44),
context_offset=ExtractedAnswer.Span(14, 16),
meta={"meta_key": "meta_value"},
)
assert answer.to_dict() == {
"type": "haystack.dataclasses.answer.ExtractedAnswer",
"init_parameters": {
"data": "42",
"query": "What is the answer?",
"document": document.to_dict(flatten=False),
"context": "The answer is 42.",
"score": 1.0,
"document_offset": {"start": 42, "end": 44},
"context_offset": {"start": 14, "end": 16},
"meta": {"meta_key": "meta_value"},
},
}
def test_from_dict(self):
answer = ExtractedAnswer.from_dict(
{
"type": "haystack.dataclasses.answer.ExtractedAnswer",
"init_parameters": {
"data": "42",
"query": "What is the answer?",
"document": {
"id": "8f800a524b139484fc719ecc35f971a080de87618319bc4836b784d69baca57f",
"content": "I thought a lot about this. The answer is 42.",
},
"context": "The answer is 42.",
"score": 1.0,
"document_offset": {"start": 42, "end": 44},
"context_offset": {"start": 14, "end": 16},
"meta": {"meta_key": "meta_value"},
},
}
)
assert answer.data == "42"
assert answer.query == "What is the answer?"
assert answer.document == Document(
id="8f800a524b139484fc719ecc35f971a080de87618319bc4836b784d69baca57f",
content="I thought a lot about this. The answer is 42.",
)
assert answer.context == "The answer is 42."
assert answer.score == 1.0
assert answer.document_offset == ExtractedAnswer.Span(42, 44)
assert answer.context_offset == ExtractedAnswer.Span(14, 16)
assert answer.meta == {"meta_key": "meta_value"}
class TestExtractedTableAnswer:
def test_init(self):
answer = ExtractedTableAnswer(
data="42",
query="What is the answer?",
document=Document(dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 42]})),
context=DataFrame({"col3": [5, 42]}),
score=1.0,
document_cells=[ExtractedTableAnswer.Cell(1, 2)],
context_cells=[ExtractedTableAnswer.Cell(1, 0)],
meta={"meta_key": "meta_value"},
)
assert answer.data == "42"
assert answer.query == "What is the answer?"
assert answer.document == Document(dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 42]}))
assert answer.context.equals(DataFrame({"col3": [5, 42]}))
assert answer.score == 1.0
assert answer.document_cells == [ExtractedTableAnswer.Cell(1, 2)]
assert answer.context_cells == [ExtractedTableAnswer.Cell(1, 0)]
assert answer.meta == {"meta_key": "meta_value"}
def test_protocol(self):
answer = ExtractedTableAnswer(
data="42",
query="What is the answer?",
document=Document(dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 42]})),
context=DataFrame({"col3": [5, 42]}),
score=1.0,
document_cells=[ExtractedTableAnswer.Cell(1, 2)],
context_cells=[ExtractedTableAnswer.Cell(1, 0)],
meta={"meta_key": "meta_value"},
)
assert isinstance(answer, Answer)
def test_to_dict(self):
document = Document(dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 42]}))
answer = ExtractedTableAnswer(
data="42",
query="What is the answer?",
document=document,
context=DataFrame({"col3": [5, 42]}),
score=1.0,
document_cells=[ExtractedTableAnswer.Cell(1, 2)],
context_cells=[ExtractedTableAnswer.Cell(1, 0)],
meta={"meta_key": "meta_value"},
)
assert answer.to_dict() == {
"type": "haystack.dataclasses.answer.ExtractedTableAnswer",
"init_parameters": {
"data": "42",
"query": "What is the answer?",
"document": document.to_dict(flatten=False),
"context": DataFrame({"col3": [5, 42]}).to_json(),
"score": 1.0,
"document_cells": [{"row": 1, "column": 2}],
"context_cells": [{"row": 1, "column": 0}],
"meta": {"meta_key": "meta_value"},
},
}
def test_from_dict(self):
answer = ExtractedTableAnswer.from_dict(
{
"type": "haystack.dataclasses.answer.ExtractedTableAnswer",
"init_parameters": {
"data": "42",
"query": "What is the answer?",
"document": {
"id": "3b13a0d56a3697e27a874fcb621911c83c59388dec213909e9e40d5d9f0affed",
"dataframe": '{"col1":{"0":1,"1":2},"col2":{"0":3,"1":4},"col3":{"0":5,"1":42}}',
},
"context": '{"col3":{"0":5,"1":42}}',
"score": 1.0,
"document_cells": [{"row": 1, "column": 2}],
"context_cells": [{"row": 1, "column": 0}],
"meta": {"meta_key": "meta_value"},
},
}
)
assert answer.data == "42"
assert answer.query == "What is the answer?"
assert answer.document == Document(
id="3b13a0d56a3697e27a874fcb621911c83c59388dec213909e9e40d5d9f0affed",
dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 42]}),
)
assert answer.context.equals(DataFrame({"col3": [5, 42]}))
assert answer.score == 1.0
assert answer.document_cells == [ExtractedTableAnswer.Cell(1, 2)]
assert answer.context_cells == [ExtractedTableAnswer.Cell(1, 0)]
assert answer.meta == {"meta_key": "meta_value"}
class TestGeneratedAnswer:
def test_init(self):
answer = GeneratedAnswer(
data="42",
query="What is the answer?",
documents=[
Document(id="1", content="The answer is 42."),
Document(id="2", content="I believe the answer is 42."),
Document(id="3", content="42 is definitely the answer."),
],
meta={"meta_key": "meta_value"},
)
assert answer.data == "42"
assert answer.query == "What is the answer?"
assert answer.documents == [
Document(id="1", content="The answer is 42."),
Document(id="2", content="I believe the answer is 42."),
Document(id="3", content="42 is definitely the answer."),
]
assert answer.meta == {"meta_key": "meta_value"}
def test_protocol(self):
answer = GeneratedAnswer(
data="42",
query="What is the answer?",
documents=[
Document(id="1", content="The answer is 42."),
Document(id="2", content="I believe the answer is 42."),
Document(id="3", content="42 is definitely the answer."),
],
meta={"meta_key": "meta_value"},
)
assert isinstance(answer, Answer)
def test_to_dict(self):
documents = [
Document(id="1", content="The answer is 42."),
Document(id="2", content="I believe the answer is 42."),
Document(id="3", content="42 is definitely the answer."),
]
answer = GeneratedAnswer(
data="42", query="What is the answer?", documents=documents, meta={"meta_key": "meta_value"}
)
assert answer.to_dict() == {
"type": "haystack.dataclasses.answer.GeneratedAnswer",
"init_parameters": {
"data": "42",
"query": "What is the answer?",
"documents": [d.to_dict(flatten=False) for d in documents],
"meta": {"meta_key": "meta_value"},
},
}
def test_from_dict(self):
answer = GeneratedAnswer.from_dict(
{
"type": "haystack.dataclasses.answer.GeneratedAnswer",
"init_parameters": {
"data": "42",
"query": "What is the answer?",
"documents": [
{"id": "1", "content": "The answer is 42."},
{"id": "2", "content": "I believe the answer is 42."},
{"id": "3", "content": "42 is definitely the answer."},
],
"meta": {"meta_key": "meta_value"},
},
}
)
assert answer.data == "42"
assert answer.query == "What is the answer?"
assert answer.documents == [
Document(id="1", content="The answer is 42."),
Document(id="2", content="I believe the answer is 42."),
Document(id="3", content="42 is definitely the answer."),
]
assert answer.meta == {"meta_key": "meta_value"}

View File

@ -1,5 +1,3 @@
from pathlib import Path
import pandas as pd
import pytest