refactor: Refactor answer dataclasses (#6523)

* Refactor answer dataclasses
* Add release notes
* Fix tests
* Fix end-to-end tests
* Enhance ExtractiveReader

This commit is contained in:
parent 820d9c37d5
commit 18dbce25fc
@@ -51,15 +51,14 @@ def test_extractive_qa_pipeline(tmp_path):
     # no_answer
     assert extracted_answers[-1].data is None

-    # since these questions are easily answerable, the best answer should have higher probability than no_answer
-    assert extracted_answers[0].probability >= extracted_answers[-1].probability
+    # since these questions are easily answerable, the best answer should have higher score than no_answer
+    assert extracted_answers[0].score >= extracted_answers[-1].score

     for answer in extracted_answers:
         assert answer.query == question

-        assert hasattr(answer, "probability")
-        assert hasattr(answer, "start")
-        assert hasattr(answer, "end")
+        assert hasattr(answer, "score")
+        assert hasattr(answer, "document_offset")

         assert hasattr(answer, "document")
         # the answer is extracted from the correct document
@@ -75,7 +75,7 @@ def test_bm25_rag_pipeline(tmp_path):
         assert spyword in generated_answer.data
         assert generated_answer.query == question
         assert hasattr(generated_answer, "documents")
-        assert hasattr(generated_answer, "metadata")
+        assert hasattr(generated_answer, "meta")


 @pytest.mark.skipif(
@@ -156,4 +156,4 @@ def test_embedding_retrieval_rag_pipeline(tmp_path):
     assert spyword in generated_answer.data
     assert generated_answer.query == question
     assert hasattr(generated_answer, "documents")
-    assert hasattr(generated_answer, "metadata")
+    assert hasattr(generated_answer, "meta")
@@ -102,7 +102,7 @@ class AnswerBuilder:
                     logger.warning("Document index '%s' referenced in Generator output is out of range. ", idx + 1)

            answer_string = AnswerBuilder._extract_answer_string(reply, pattern)
-            answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, metadata=meta)
+            answer = GeneratedAnswer(data=answer_string, query=query, documents=referenced_docs, meta=meta)
            all_answers.append(answer)

        return {"answers": all_answers}
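For reference, a minimal usage sketch of the component after this change; only the field on the produced `GeneratedAnswer` moves from `metadata` to `meta`, mirroring the assertions in the updated tests further down (import path assumed, pipeline wiring omitted):

from haystack.components.builders import AnswerBuilder

builder = AnswerBuilder()
result = builder.run(query="What is the answer?", replies=["The answer is 42."])
answer = result["answers"][0]
assert answer.meta == {}  # field renamed from `metadata`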
@@ -273,44 +273,42 @@ class ExtractiveReader:
         logit introduced with SQuAD 2. Instead, it just computes the probability that the answer does not exist
         in the top k or top p.
         """
-        flat_answers_without_queries = []
+        answers_without_query = []
         for document_id, start_candidates_, end_candidates_, probabilities_ in zip(
             document_ids, start, end, probabilities
         ):
             for start_, end_, probability in zip(start_candidates_, end_candidates_, probabilities_):
                 doc = flattened_documents[document_id]
                 # doc.content cannot be None, because those documents are filtered when preprocessing.
                 # However, mypy doesn't know that.
-                flat_answers_without_queries.append(
-                    {
-                        "data": doc.content[start_:end_],  # type: ignore
-                        "document": doc,
-                        "probability": probability.item(),
-                        "start": start_,
-                        "end": end_,
-                        "metadata": {},
-                    }
+                answers_without_query.append(
+                    ExtractedAnswer(
+                        query="",  # Can't be None but we'll add it later
+                        data=doc.content[start_:end_],  # type: ignore
+                        document=doc,
+                        score=probability.item(),
+                        document_offset=ExtractedAnswer.Span(start_, end_),
+                        meta={},
+                    )
                 )
         i = 0
         nested_answers = []
         for query_id in range(query_ids[-1] + 1):
             current_answers = []
-            while i < len(flat_answers_without_queries) and query_ids[i // answers_per_seq] == query_id:
-                answer = flat_answers_without_queries[i]
-                answer["query"] = queries[query_id]
-                current_answers.append(ExtractedAnswer(**answer))
+            while i < len(answers_without_query) and query_ids[i // answers_per_seq] == query_id:
+                answer = answers_without_query[i]
+                answer.query = queries[query_id]
+                current_answers.append(answer)
                 i += 1
-            current_answers = sorted(current_answers, key=lambda answer: answer.probability, reverse=True)
+            current_answers = sorted(current_answers, key=lambda answer: answer.score, reverse=True)
             current_answers = current_answers[:top_k]
             if no_answer:
-                no_answer_probability = math.prod(1 - answer.probability for answer in current_answers)
+                no_answer_score = math.prod(1 - answer.score for answer in current_answers)
                 answer_ = ExtractedAnswer(
-                    data=None, query=queries[query_id], metadata={}, document=None, probability=no_answer_probability
+                    data=None, query=queries[query_id], meta={}, document=None, score=no_answer_score
                 )
                 current_answers.append(answer_)
-                current_answers = sorted(current_answers, key=lambda answer: answer.probability, reverse=True)
+                current_answers = sorted(current_answers, key=lambda answer: answer.score, reverse=True)
             if confidence_threshold is not None:
-                current_answers = [answer for answer in current_answers if answer.probability >= confidence_threshold]
+                current_answers = [answer for answer in current_answers if answer.score >= confidence_threshold]
             nested_answers.append(current_answers)

         return nested_answers
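The no_answer entry's score is the probability that none of the current top answers is correct, i.e. the product of (1 - score) over them. A standalone sketch of the same computation, with made-up scores:

import math

scores = [0.8, 0.6, 0.3]  # calibrated scores of the extracted answers for one query
no_answer_score = math.prod(1 - s for s in scores)
assert round(no_answer_score, 3) == 0.056  # (1 - 0.8) * (1 - 0.6) * (1 - 0.3)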
@@ -1,25 +1,139 @@
-from typing import Any, Dict, List, Optional
-from dataclasses import dataclass
+import io
+from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
+from dataclasses import dataclass, field, asdict
+
+from pandas import DataFrame, read_json

+from haystack.core.serialization import default_from_dict, default_to_dict
 from haystack.dataclasses.document import Document


-@dataclass(frozen=True)
-class Answer:
+@runtime_checkable
+@dataclass
+class Answer(Protocol):
     data: Any
     query: str
-    metadata: Dict[str, Any]
+    meta: Dict[str, Any]
+
+    def to_dict(self) -> Dict[str, Any]:
+        ...
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "Answer":
+        ...
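Since `Answer` is now a `runtime_checkable` Protocol, the concrete classes below satisfy it structurally instead of inheriting from it, and `isinstance()` checks keep working; the new test file at the bottom of this commit exercises exactly this. A minimal sketch, with illustrative field values:

from haystack.dataclasses.answer import Answer, ExtractedAnswer

answer = ExtractedAnswer(query="What is the answer?", score=0.9, data="42")
assert isinstance(answer, Answer)  # structural check enabled by @runtime_checkable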
-@dataclass(frozen=True)
-class ExtractedAnswer(Answer):
-    data: Optional[str]
-    document: Optional[Document]
-    probability: float
-    start: Optional[int] = None
-    end: Optional[int] = None
+@dataclass
+class ExtractedAnswer:
+    query: str
+    score: float
+    data: Optional[str] = None
+    document: Optional[Document] = None
+    context: Optional[str] = None
+    document_offset: Optional["Span"] = None
+    context_offset: Optional["Span"] = None
+    meta: Dict[str, Any] = field(default_factory=dict)
+
+    @dataclass
+    class Span:
+        start: int
+        end: int
+
+    def to_dict(self) -> Dict[str, Any]:
+        document = self.document.to_dict(flatten=False) if self.document is not None else None
+        document_offset = asdict(self.document_offset) if self.document_offset is not None else None
+        context_offset = asdict(self.context_offset) if self.context_offset is not None else None
+        return default_to_dict(
+            self,
+            data=self.data,
+            query=self.query,
+            document=document,
+            context=self.context,
+            score=self.score,
+            document_offset=document_offset,
+            context_offset=context_offset,
+            meta=self.meta,
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ExtractedAnswer":
+        init_params = data.get("init_parameters", {})
+        if (doc := init_params.get("document")) is not None:
+            data["init_parameters"]["document"] = Document.from_dict(doc)
+
+        if (offset := init_params.get("document_offset")) is not None:
+            data["init_parameters"]["document_offset"] = ExtractedAnswer.Span(**offset)
+
+        if (offset := init_params.get("context_offset")) is not None:
+            data["init_parameters"]["context_offset"] = ExtractedAnswer.Span(**offset)
+        return default_from_dict(cls, data)
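The `to_dict()`/`from_dict()` pair mirrors `Document` serialization: nested objects (`Document`, `Span`) are flattened to plain dicts and rebuilt on deserialization. A round-trip sketch, assuming only the code above:

from haystack.dataclasses.answer import ExtractedAnswer

answer = ExtractedAnswer(
    query="What is the answer?",
    score=1.0,
    data="42",
    document_offset=ExtractedAnswer.Span(42, 44),
)
payload = answer.to_dict()  # {"type": "haystack.dataclasses.answer.ExtractedAnswer", "init_parameters": {...}}
assert ExtractedAnswer.from_dict(payload) == answer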
-@dataclass(frozen=True)
-class GeneratedAnswer(Answer):
+@dataclass
+class ExtractedTableAnswer:
+    query: str
+    score: float
+    data: Optional[str] = None
+    document: Optional[Document] = None
+    context: Optional[DataFrame] = None
+    document_cells: List["Cell"] = field(default_factory=list)
+    context_cells: List["Cell"] = field(default_factory=list)
+    meta: Dict[str, Any] = field(default_factory=dict)
+
+    @dataclass
+    class Cell:
+        row: int
+        column: int
+
+    def to_dict(self) -> Dict[str, Any]:
+        document = self.document.to_dict(flatten=False) if self.document is not None else None
+        context = self.context.to_json() if self.context is not None else None
+        document_cells = [asdict(c) for c in self.document_cells]
+        context_cells = [asdict(c) for c in self.context_cells]
+        return default_to_dict(
+            self,
+            data=self.data,
+            query=self.query,
+            document=document,
+            context=context,
+            score=self.score,
+            document_cells=document_cells,
+            context_cells=context_cells,
+            meta=self.meta,
+        )
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "ExtractedTableAnswer":
+        init_params = data.get("init_parameters", {})
+        if (doc := init_params.get("document")) is not None:
+            data["init_parameters"]["document"] = Document.from_dict(doc)
+
+        if (context := init_params.get("context")) is not None:
+            data["init_parameters"]["context"] = read_json(io.StringIO(context))
+
+        if (cells := init_params.get("document_cells")) is not None:
+            data["init_parameters"]["document_cells"] = [ExtractedTableAnswer.Cell(**c) for c in cells]
+
+        if (cells := init_params.get("context_cells")) is not None:
+            data["init_parameters"]["context_cells"] = [ExtractedTableAnswer.Cell(**c) for c in cells]
+        return default_from_dict(cls, data)
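For the table variant, the DataFrame context survives serialization as a JSON string via pandas: `to_dict()` calls `DataFrame.to_json()` and `from_dict()` rebuilds it with `read_json`. A standalone sketch of that round trip:

import io
from pandas import DataFrame, read_json

df = DataFrame({"col3": [5, 42]})
payload = df.to_json()  # '{"col3":{"0":5,"1":42}}'
restored = read_json(io.StringIO(payload))
assert restored.equals(df)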
+@dataclass
+class GeneratedAnswer:
     data: str
+    query: str
     documents: List[Document]
+    meta: Dict[str, Any] = field(default_factory=dict)
+
+    def to_dict(self) -> Dict[str, Any]:
+        documents = [doc.to_dict(flatten=False) for doc in self.documents]
+        return default_to_dict(self, data=self.data, query=self.query, documents=documents, meta=self.meta)
+
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any]) -> "GeneratedAnswer":
+        init_params = data.get("init_parameters", {})
+        if (documents := init_params.get("documents")) is not None:
+            data["init_parameters"]["documents"] = [Document.from_dict(d) for d in documents]
+
+        return default_from_dict(cls, data)
@@ -0,0 +1,8 @@
+---
+enhancements:
+  - |
+    Refactor the `Answer` dataclass and the classes that inherited from it.
+    `Answer` is now a Protocol; the classes that used to inherit from it now conform to that interface instead.
+    We also added a new `ExtractedTableAnswer` to be used for table question answering.
+
+    All classes are now easily serializable with `to_dict()` and `from_dict()`, like `Document` and components.
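For downstream code, the rename map is `probability` -> `score`, `metadata` -> `meta`, and the bare `start`/`end` integers -> a `Span` on `document_offset`. A before/after sketch with illustrative values:

from haystack.dataclasses.answer import ExtractedAnswer

answer = ExtractedAnswer(query="q", score=0.9, data="a", document_offset=ExtractedAnswer.Span(11, 16))
confidence = answer.score  # was: answer.probability
extra = answer.meta  # was: answer.metadata
start, end = answer.document_offset.start, answer.document_offset.end  # was: answer.start, answer.end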
@@ -17,7 +17,7 @@ class TestAnswerBuilder:
         output = component.run(query="query", replies=["reply1"])
         answers = output["answers"]
         assert answers[0].data == "reply1"
-        assert answers[0].metadata == {}
+        assert answers[0].meta == {}
         assert answers[0].query == "query"
         assert answers[0].documents == []
         assert isinstance(answers[0], GeneratedAnswer)
@@ -27,7 +27,7 @@ class TestAnswerBuilder:
         output = component.run(query="query", replies=["reply1"], metadata=[])
         answers = output["answers"]
         assert answers[0].data == "reply1"
-        assert answers[0].metadata == {}
+        assert answers[0].meta == {}
         assert answers[0].query == "query"
         assert answers[0].documents == []
         assert isinstance(answers[0], GeneratedAnswer)
@@ -38,7 +38,7 @@ class TestAnswerBuilder:
         answers = output["answers"]
         assert len(answers) == 1
         assert answers[0].data == "Answer: AnswerString"
-        assert answers[0].metadata == {}
+        assert answers[0].meta == {}
         assert answers[0].query == "test query"
         assert answers[0].documents == []
         assert isinstance(answers[0], GeneratedAnswer)
@@ -49,7 +49,7 @@ class TestAnswerBuilder:
         answers = output["answers"]
         assert len(answers) == 1
         assert answers[0].data == "AnswerString"
-        assert answers[0].metadata == {}
+        assert answers[0].meta == {}
         assert answers[0].query == "test query"
         assert answers[0].documents == []
         assert isinstance(answers[0], GeneratedAnswer)
@@ -60,7 +60,7 @@ class TestAnswerBuilder:
         answers = output["answers"]
         assert len(answers) == 1
         assert answers[0].data == "'AnswerString'"
-        assert answers[0].metadata == {}
+        assert answers[0].meta == {}
         assert answers[0].query == "test query"
         assert answers[0].documents == []
         assert isinstance(answers[0], GeneratedAnswer)
@@ -77,7 +77,7 @@ class TestAnswerBuilder:
         answers = output["answers"]
         assert len(answers) == 1
         assert answers[0].data == "AnswerString"
-        assert answers[0].metadata == {}
+        assert answers[0].meta == {}
         assert answers[0].query == "test query"
         assert answers[0].documents == []
         assert isinstance(answers[0], GeneratedAnswer)
@@ -93,7 +93,7 @@ class TestAnswerBuilder:
         answers = output["answers"]
         assert len(answers) == 1
         assert answers[0].data == "Answer: AnswerString"
-        assert answers[0].metadata == {}
+        assert answers[0].meta == {}
         assert answers[0].query == "test query"
         assert len(answers[0].documents) == 2
         assert answers[0].documents[0].content == "test doc 1"
@@ -110,7 +110,7 @@ class TestAnswerBuilder:
         answers = output["answers"]
         assert len(answers) == 1
         assert answers[0].data == "Answer: AnswerString[2]"
-        assert answers[0].metadata == {}
+        assert answers[0].meta == {}
         assert answers[0].query == "test query"
         assert len(answers[0].documents) == 1
         assert answers[0].documents[0].content == "test doc 2"
@@ -127,7 +127,7 @@ class TestAnswerBuilder:
         answers = output["answers"]
         assert len(answers) == 1
         assert answers[0].data == "Answer: AnswerString[3]"
-        assert answers[0].metadata == {}
+        assert answers[0].meta == {}
         assert answers[0].query == "test query"
         assert len(answers[0].documents) == 0
         assert "Document index '3' referenced in Generator output is out of range." in caplog.text
@@ -144,7 +144,7 @@ class TestAnswerBuilder:
         answers = output["answers"]
         assert len(answers) == 1
         assert answers[0].data == "Answer: AnswerString[2][3]"
-        assert answers[0].metadata == {}
+        assert answers[0].meta == {}
         assert answers[0].query == "test query"
         assert len(answers[0].documents) == 2
         assert answers[0].documents[0].content == "test doc 2"
@@ -140,15 +140,15 @@ def test_output(mock_reader: ExtractiveReader):
     doc_ids = set()
     no_answer_prob = 1
     for doc, answer in zip(example_documents[0], answers[:3]):
-        assert answer.start == 11
-        assert answer.end == 16
+        assert answer.document_offset.start == 11
+        assert answer.document_offset.end == 16
         assert doc.content is not None
         assert answer.data == doc.content[11:16]
-        assert answer.probability == pytest.approx(1 / (1 + exp(-2 * mock_reader.calibration_factor)))
-        no_answer_prob *= 1 - answer.probability
+        assert answer.score == pytest.approx(1 / (1 + exp(-2 * mock_reader.calibration_factor)))
+        no_answer_prob *= 1 - answer.score
         doc_ids.add(doc.id)
     assert len(doc_ids) == 3
-    assert answers[-1].probability == pytest.approx(no_answer_prob)
+    assert answers[-1].score == pytest.approx(no_answer_prob)


 def test_flatten_documents(mock_reader: ExtractiveReader):
@@ -241,14 +241,14 @@ def test_nest_answers(mock_reader: ExtractiveReader):
         example_queries, nested_answers, expected_no_answers, [probabilities[:3, -1], probabilities[3:, -1]]
     ):
         assert len(answers) == 4
-        for doc, answer, probability in zip(example_documents[0], reversed(answers[:3]), probabilities):
+        for doc, answer, score in zip(example_documents[0], reversed(answers[:3]), probabilities):
             assert answer.query == query
             assert answer.document == doc
-            assert answer.probability == pytest.approx(probability)
+            assert answer.score == pytest.approx(score)
         no_answer = answers[-1]
         assert no_answer.query == query
         assert no_answer.document is None
-        assert no_answer.probability == pytest.approx(expected_no_answer)
+        assert no_answer.score == pytest.approx(expected_no_answer)


 @patch("haystack.components.readers.extractive.AutoTokenizer.from_pretrained")
@@ -269,19 +269,19 @@ def test_t5():
         "answers"
     ]  # remove indices when batching support is reintroduced
     assert answers[0].data == "Angela Merkel"
-    assert answers[0].probability == pytest.approx(0.7764519453048706)
+    assert answers[0].score == pytest.approx(0.7764519453048706)
     assert answers[1].data == "Olaf Scholz"
-    assert answers[1].probability == pytest.approx(0.7703777551651001)
+    assert answers[1].score == pytest.approx(0.7703777551651001)
     assert answers[2].data is None
-    assert answers[2].probability == pytest.approx(0.051331606147570596)
+    assert answers[2].score == pytest.approx(0.051331606147570596)
     # Uncomment assertions below when batching is reintroduced
-    # assert answers[0][2].probability == pytest.approx(0.051331606147570596)
+    # assert answers[0][2].score == pytest.approx(0.051331606147570596)
     # assert answers[1][0].data == "Jerry"
-    # assert answers[1][0].probability == pytest.approx(0.7413333654403687)
+    # assert answers[1][0].score == pytest.approx(0.7413333654403687)
     # assert answers[1][1].data == "Olaf Scholz"
-    # assert answers[1][1].probability == pytest.approx(0.7266613841056824)
+    # assert answers[1][1].score == pytest.approx(0.7266613841056824)
     # assert answers[1][2].data is None
-    # assert answers[1][2].probability == pytest.approx(0.0707035798685709)
+    # assert answers[1][2].score == pytest.approx(0.0707035798685709)


 @pytest.mark.integration
@@ -292,24 +292,24 @@ def test_roberta():
         "answers"
     ]  # remove indices when batching is reintroduced
     assert answers[0].data == "Olaf Scholz"
-    assert answers[0].probability == pytest.approx(0.8614975214004517)
+    assert answers[0].score == pytest.approx(0.8614975214004517)
     assert answers[1].data == "Angela Merkel"
-    assert answers[1].probability == pytest.approx(0.857952892780304)
+    assert answers[1].score == pytest.approx(0.857952892780304)
     assert answers[2].data is None
-    assert answers[2].probability == pytest.approx(0.019673851661650588)
+    assert answers[2].score == pytest.approx(0.019673851661650588)
     # uncomment assertions below when there is batching in v2
     # assert answers[0][0].data == "Olaf Scholz"
-    # assert answers[0][0].probability == pytest.approx(0.8614975214004517)
+    # assert answers[0][0].score == pytest.approx(0.8614975214004517)
     # assert answers[0][1].data == "Angela Merkel"
-    # assert answers[0][1].probability == pytest.approx(0.857952892780304)
+    # assert answers[0][1].score == pytest.approx(0.857952892780304)
     # assert answers[0][2].data is None
-    # assert answers[0][2].probability == pytest.approx(0.0196738764278237)
+    # assert answers[0][2].score == pytest.approx(0.0196738764278237)
     # assert answers[1][0].data == "Jerry"
-    # assert answers[1][0].probability == pytest.approx(0.7048940658569336)
+    # assert answers[1][0].score == pytest.approx(0.7048940658569336)
     # assert answers[1][1].data == "Olaf Scholz"
-    # assert answers[1][1].probability == pytest.approx(0.6604189872741699)
+    # assert answers[1][1].score == pytest.approx(0.6604189872741699)
     # assert answers[1][2].data is None
-    # assert answers[1][2].probability == pytest.approx(0.1002123719777046)
+    # assert answers[1][2].score == pytest.approx(0.1002123719777046)


 @pytest.mark.integration
@@ -329,6 +329,6 @@ def test_matches_hf_pipeline():
     )  # We need to disable HF postprocessing features to make the results comparable. This is related to https://github.com/huggingface/transformers/issues/26286
     assert len(answers) == len(answers_hf) == 20
     for answer, answer_hf in zip(answers, answers_hf):
-        assert answer.start == answer_hf["start"]
-        assert answer.end == answer_hf["end"]
+        assert answer.document_offset.start == answer_hf["start"]
+        assert answer.document_offset.end == answer_hf["end"]
         assert answer.data == answer_hf["answer"]
268  test/dataclasses/test_answer.py  Normal file
@@ -0,0 +1,268 @@
+from pandas import DataFrame
+
+from haystack.dataclasses.answer import Answer, ExtractedAnswer, ExtractedTableAnswer, GeneratedAnswer
+from haystack.dataclasses.document import Document
+
+
+class TestExtractedAnswer:
+    def test_init(self):
+        answer = ExtractedAnswer(
+            data="42",
+            query="What is the answer?",
+            document=Document(content="I thought a lot about this. The answer is 42."),
+            context="The answer is 42.",
+            score=1.0,
+            document_offset=ExtractedAnswer.Span(42, 44),
+            context_offset=ExtractedAnswer.Span(14, 16),
+            meta={"meta_key": "meta_value"},
+        )
+        assert answer.data == "42"
+        assert answer.query == "What is the answer?"
+        assert answer.document == Document(content="I thought a lot about this. The answer is 42.")
+        assert answer.context == "The answer is 42."
+        assert answer.score == 1.0
+        assert answer.document_offset == ExtractedAnswer.Span(42, 44)
+        assert answer.context_offset == ExtractedAnswer.Span(14, 16)
+        assert answer.meta == {"meta_key": "meta_value"}
+
+    def test_protocol(self):
+        answer = ExtractedAnswer(
+            data="42",
+            query="What is the answer?",
+            document=Document(content="I thought a lot about this. The answer is 42."),
+            context="The answer is 42.",
+            score=1.0,
+            document_offset=ExtractedAnswer.Span(42, 44),
+            context_offset=ExtractedAnswer.Span(14, 16),
+            meta={"meta_key": "meta_value"},
+        )
+        assert isinstance(answer, Answer)
+
+    def test_to_dict(self):
+        document = Document(content="I thought a lot about this. The answer is 42.")
+        answer = ExtractedAnswer(
+            data="42",
+            query="What is the answer?",
+            document=document,
+            context="The answer is 42.",
+            score=1.0,
+            document_offset=ExtractedAnswer.Span(42, 44),
+            context_offset=ExtractedAnswer.Span(14, 16),
+            meta={"meta_key": "meta_value"},
+        )
+        assert answer.to_dict() == {
+            "type": "haystack.dataclasses.answer.ExtractedAnswer",
+            "init_parameters": {
+                "data": "42",
+                "query": "What is the answer?",
+                "document": document.to_dict(flatten=False),
+                "context": "The answer is 42.",
+                "score": 1.0,
+                "document_offset": {"start": 42, "end": 44},
+                "context_offset": {"start": 14, "end": 16},
+                "meta": {"meta_key": "meta_value"},
+            },
+        }
+
+    def test_from_dict(self):
+        answer = ExtractedAnswer.from_dict(
+            {
+                "type": "haystack.dataclasses.answer.ExtractedAnswer",
+                "init_parameters": {
+                    "data": "42",
+                    "query": "What is the answer?",
+                    "document": {
+                        "id": "8f800a524b139484fc719ecc35f971a080de87618319bc4836b784d69baca57f",
+                        "content": "I thought a lot about this. The answer is 42.",
+                    },
+                    "context": "The answer is 42.",
+                    "score": 1.0,
+                    "document_offset": {"start": 42, "end": 44},
+                    "context_offset": {"start": 14, "end": 16},
+                    "meta": {"meta_key": "meta_value"},
+                },
+            }
+        )
+        assert answer.data == "42"
+        assert answer.query == "What is the answer?"
+        assert answer.document == Document(
+            id="8f800a524b139484fc719ecc35f971a080de87618319bc4836b784d69baca57f",
+            content="I thought a lot about this. The answer is 42.",
+        )
+        assert answer.context == "The answer is 42."
+        assert answer.score == 1.0
+        assert answer.document_offset == ExtractedAnswer.Span(42, 44)
+        assert answer.context_offset == ExtractedAnswer.Span(14, 16)
+        assert answer.meta == {"meta_key": "meta_value"}
+
+
+class TestExtractedTableAnswer:
+    def test_init(self):
+        answer = ExtractedTableAnswer(
+            data="42",
+            query="What is the answer?",
+            document=Document(dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 42]})),
+            context=DataFrame({"col3": [5, 42]}),
+            score=1.0,
+            document_cells=[ExtractedTableAnswer.Cell(1, 2)],
+            context_cells=[ExtractedTableAnswer.Cell(1, 0)],
+            meta={"meta_key": "meta_value"},
+        )
+        assert answer.data == "42"
+        assert answer.query == "What is the answer?"
+        assert answer.document == Document(dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 42]}))
+        assert answer.context.equals(DataFrame({"col3": [5, 42]}))
+        assert answer.score == 1.0
+        assert answer.document_cells == [ExtractedTableAnswer.Cell(1, 2)]
+        assert answer.context_cells == [ExtractedTableAnswer.Cell(1, 0)]
+        assert answer.meta == {"meta_key": "meta_value"}
+
+    def test_protocol(self):
+        answer = ExtractedTableAnswer(
+            data="42",
+            query="What is the answer?",
+            document=Document(dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 42]})),
+            context=DataFrame({"col3": [5, 42]}),
+            score=1.0,
+            document_cells=[ExtractedTableAnswer.Cell(1, 2)],
+            context_cells=[ExtractedTableAnswer.Cell(1, 0)],
+            meta={"meta_key": "meta_value"},
+        )
+        assert isinstance(answer, Answer)
+
+    def test_to_dict(self):
+        document = Document(dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 42]}))
+        answer = ExtractedTableAnswer(
+            data="42",
+            query="What is the answer?",
+            document=document,
+            context=DataFrame({"col3": [5, 42]}),
+            score=1.0,
+            document_cells=[ExtractedTableAnswer.Cell(1, 2)],
+            context_cells=[ExtractedTableAnswer.Cell(1, 0)],
+            meta={"meta_key": "meta_value"},
+        )
+
+        assert answer.to_dict() == {
+            "type": "haystack.dataclasses.answer.ExtractedTableAnswer",
+            "init_parameters": {
+                "data": "42",
+                "query": "What is the answer?",
+                "document": document.to_dict(flatten=False),
+                "context": DataFrame({"col3": [5, 42]}).to_json(),
+                "score": 1.0,
+                "document_cells": [{"row": 1, "column": 2}],
+                "context_cells": [{"row": 1, "column": 0}],
+                "meta": {"meta_key": "meta_value"},
+            },
+        }
+
+    def test_from_dict(self):
+        answer = ExtractedTableAnswer.from_dict(
+            {
+                "type": "haystack.dataclasses.answer.ExtractedTableAnswer",
+                "init_parameters": {
+                    "data": "42",
+                    "query": "What is the answer?",
+                    "document": {
+                        "id": "3b13a0d56a3697e27a874fcb621911c83c59388dec213909e9e40d5d9f0affed",
+                        "dataframe": '{"col1":{"0":1,"1":2},"col2":{"0":3,"1":4},"col3":{"0":5,"1":42}}',
+                    },
+                    "context": '{"col3":{"0":5,"1":42}}',
+                    "score": 1.0,
+                    "document_cells": [{"row": 1, "column": 2}],
+                    "context_cells": [{"row": 1, "column": 0}],
+                    "meta": {"meta_key": "meta_value"},
+                },
+            }
+        )
+        assert answer.data == "42"
+        assert answer.query == "What is the answer?"
+        assert answer.document == Document(
+            id="3b13a0d56a3697e27a874fcb621911c83c59388dec213909e9e40d5d9f0affed",
+            dataframe=DataFrame({"col1": [1, 2], "col2": [3, 4], "col3": [5, 42]}),
+        )
+        assert answer.context.equals(DataFrame({"col3": [5, 42]}))
+        assert answer.score == 1.0
+        assert answer.document_cells == [ExtractedTableAnswer.Cell(1, 2)]
+        assert answer.context_cells == [ExtractedTableAnswer.Cell(1, 0)]
+        assert answer.meta == {"meta_key": "meta_value"}
+
+
+class TestGeneratedAnswer:
+    def test_init(self):
+        answer = GeneratedAnswer(
+            data="42",
+            query="What is the answer?",
+            documents=[
+                Document(id="1", content="The answer is 42."),
+                Document(id="2", content="I believe the answer is 42."),
+                Document(id="3", content="42 is definitely the answer."),
+            ],
+            meta={"meta_key": "meta_value"},
+        )
+        assert answer.data == "42"
+        assert answer.query == "What is the answer?"
+        assert answer.documents == [
+            Document(id="1", content="The answer is 42."),
+            Document(id="2", content="I believe the answer is 42."),
+            Document(id="3", content="42 is definitely the answer."),
+        ]
+        assert answer.meta == {"meta_key": "meta_value"}
+
+    def test_protocol(self):
+        answer = GeneratedAnswer(
+            data="42",
+            query="What is the answer?",
+            documents=[
+                Document(id="1", content="The answer is 42."),
+                Document(id="2", content="I believe the answer is 42."),
+                Document(id="3", content="42 is definitely the answer."),
+            ],
+            meta={"meta_key": "meta_value"},
+        )
+        assert isinstance(answer, Answer)
+
+    def test_to_dict(self):
+        documents = [
+            Document(id="1", content="The answer is 42."),
+            Document(id="2", content="I believe the answer is 42."),
+            Document(id="3", content="42 is definitely the answer."),
+        ]
+        answer = GeneratedAnswer(
+            data="42", query="What is the answer?", documents=documents, meta={"meta_key": "meta_value"}
+        )
+        assert answer.to_dict() == {
+            "type": "haystack.dataclasses.answer.GeneratedAnswer",
+            "init_parameters": {
+                "data": "42",
+                "query": "What is the answer?",
+                "documents": [d.to_dict(flatten=False) for d in documents],
+                "meta": {"meta_key": "meta_value"},
+            },
+        }
+
+    def test_from_dict(self):
+        answer = GeneratedAnswer.from_dict(
+            {
+                "type": "haystack.dataclasses.answer.GeneratedAnswer",
+                "init_parameters": {
+                    "data": "42",
+                    "query": "What is the answer?",
+                    "documents": [
+                        {"id": "1", "content": "The answer is 42."},
+                        {"id": "2", "content": "I believe the answer is 42."},
+                        {"id": "3", "content": "42 is definitely the answer."},
+                    ],
+                    "meta": {"meta_key": "meta_value"},
+                },
+            }
+        )
+        assert answer.data == "42"
+        assert answer.query == "What is the answer?"
+        assert answer.documents == [
+            Document(id="1", content="The answer is 42."),
+            Document(id="2", content="I believe the answer is 42."),
+            Document(id="3", content="42 is definitely the answer."),
+        ]
+        assert answer.meta == {"meta_key": "meta_value"}
@@ -1,5 +1,3 @@
-from pathlib import Path
-
 import pandas as pd
 import pytest