mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-31 11:56:35 +00:00
Add RouteDocuments
and JoinAnswers
nodes (#2256)
* Add SplitDocumentList and JoinAnswer nodes * Update Documentation & Code Style * Add tests + adapt tutorial * Update Documentation & Code Style * Remove branch from installation path in Tutorial * Update Documentation & Code Style * Fix typing * Update Documentation & Code Style * Change name of SplitDocumentList to RouteDocuments * Update Documentation & Code Style * Adapt tutorials to new name * Add test for JoinAnswers * Update Documentation & Code Style * Adapt name of test for JoinAnswers node Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
This commit is contained in:
parent
11eebf8097
commit
c5542bd3fb
@ -38,8 +38,9 @@ Make sure you enable the GPU runtime to experience decent speed in this tutorial
|
||||
# The TaPAs-based TableReader requires the torch-scatter library
|
||||
!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html
|
||||
|
||||
# If you run this notebook on Google Colab, you might need to
|
||||
# restart the runtime after installing haystack.
|
||||
# Install pygraphviz for visualization of Pipelines
|
||||
!apt install libgraphviz-dev
|
||||
!pip install pygraphviz
|
||||
```
|
||||
|
||||
### Start an Elasticsearch server
|
||||
@ -94,7 +95,7 @@ Just as text passages, tables are represented as `Document` objects in Haystack.
|
||||
from haystack.utils import fetch_archive_from_http
|
||||
|
||||
doc_dir = "data"
|
||||
s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/ottqa_tables_sample.json.zip"
|
||||
s3_url = "https://s3.eu-central-1.amazonaws.com/deepset.ai-farm-qa/datasets/documents/ottqa_sample.zip"
|
||||
fetch_archive_from_http(url=s3_url, output_dir=doc_dir)
|
||||
```
|
||||
|
||||
@ -246,6 +247,101 @@ prediction = table_qa_pipeline.run("How many twin buildings are under constructi
|
||||
print_answers(prediction, details="minimum")
|
||||
```
|
||||
|
||||
# Open-Domain QA on Text and Tables
|
||||
With haystack, you not only have the possibility to do QA on texts or tables, solely, but you can also use both texts and tables as your source of information.
|
||||
|
||||
To demonstrate this, we add 1,000 sample text passages from the OTT-QA dataset.
|
||||
|
||||
|
||||
```python
|
||||
# Add 1,000 text passages from OTT-QA to our document store.
|
||||
|
||||
|
||||
def read_ottqa_texts(filename):
|
||||
processed_passages = []
|
||||
with open(filename) as passages:
|
||||
passages = json.load(passages)
|
||||
for title, content in passages.items():
|
||||
title = title[6:]
|
||||
title = title.replace("_", " ")
|
||||
document = Document(content=content, content_type="text", meta={"title": title})
|
||||
processed_passages.append(document)
|
||||
|
||||
return processed_passages
|
||||
|
||||
|
||||
passages = read_ottqa_texts(f"{doc_dir}/ottqa_texts_sample.json")
|
||||
document_store.write_documents(passages, index=document_index)
|
||||
```
|
||||
|
||||
|
||||
```python
|
||||
document_store.update_embeddings(retriever=retriever, update_existing_embeddings=False)
|
||||
```
|
||||
|
||||
## Pipeline for QA on Combination of Text and Tables
|
||||
We are using one node for retrieving both texts and tables, the `TableTextRetriever`. In order to do question-answering on the Documents coming from the `TableTextRetriever`, we need to route Documents of type `"text"` to a `FARMReader` (or alternatively `TransformersReader`) and Documents of type `"table"` to a `TableReader`.
|
||||
|
||||
To achieve this, we make use of two additional nodes:
|
||||
- `SplitDocumentList`: Splits the List of Documents retrieved by the `TableTextRetriever` into two lists containing only Documents of type `"text"` or `"table"`, respectively.
|
||||
- `JoinAnswers`: Takes Answers coming from two different Readers (in this case `FARMReader` and `TableReader`) and joins them to a single list of Answers.
|
||||
|
||||
|
||||
```python
|
||||
from haystack.nodes import FARMReader, RouteDocuments, JoinAnswers
|
||||
|
||||
text_reader = FARMReader("deepset/roberta-base-squad2")
|
||||
# In order to get meaningful scores from the TableReader, use "deepset/tapas-large-nq-hn-reader" or
|
||||
# "deepset/tapas-large-nq-reader" as TableReader models. The disadvantage of these models is, however,
|
||||
# that they are not capable of doing aggregations over multiple table cells.
|
||||
table_reader = TableReader("deepset/tapas-large-nq-hn-reader")
|
||||
route_documents = RouteDocuments()
|
||||
join_answers = JoinAnswers()
|
||||
```
|
||||
|
||||
|
||||
```python
|
||||
text_table_qa_pipeline = Pipeline()
|
||||
text_table_qa_pipeline.add_node(component=retriever, name="TableTextRetriever", inputs=["Query"])
|
||||
text_table_qa_pipeline.add_node(component=route_documents, name="RouteDocuments", inputs=["TableTextRetriever"])
|
||||
text_table_qa_pipeline.add_node(component=text_reader, name="TextReader", inputs=["RouteDocuments.output_1"])
|
||||
text_table_qa_pipeline.add_node(component=table_reader, name="TableReader", inputs=["RouteDocuments.output_2"])
|
||||
text_table_qa_pipeline.add_node(component=join_answers, name="JoinAnswers", inputs=["TextReader", "TableReader"])
|
||||
```
|
||||
|
||||
|
||||
```python
|
||||
# Let's have a look on the structure of the combined Table an Text QA pipeline.
|
||||
from IPython import display
|
||||
|
||||
text_table_qa_pipeline.draw()
|
||||
display.Image("pipeline.png")
|
||||
```
|
||||
|
||||
|
||||
```python
|
||||
# Example query whose answer resides in a text passage
|
||||
predictions = text_table_qa_pipeline.run(query="Who is Aleksandar Trifunovic?")
|
||||
```
|
||||
|
||||
|
||||
```python
|
||||
# We can see both text passages and tables as contexts of the predicted answers.
|
||||
print_answers(predictions, details="minimum")
|
||||
```
|
||||
|
||||
|
||||
```python
|
||||
# Example query whose answer resides in a table
|
||||
predictions = text_table_qa_pipeline.run(query="What is Cuba's national tree?")
|
||||
```
|
||||
|
||||
|
||||
```python
|
||||
# We can see both text passages and tables as contexts of the predicted answers.
|
||||
print_answers(predictions, details="minimum")
|
||||
```
|
||||
|
||||
## About us
|
||||
|
||||
This [Haystack](https://github.com/deepset-ai/haystack/) notebook was made with love by [deepset](https://deepset.ai/) in Berlin, Germany
|
||||
|
@ -102,7 +102,7 @@ except ImportError:
|
||||
|
||||
from haystack.modeling.evaluation import eval
|
||||
from haystack.modeling.logger import MLFlowLogger, StdoutLogger, TensorBoardLogger
|
||||
from haystack.nodes.other import JoinDocuments, Docs2Answers
|
||||
from haystack.nodes.other import JoinDocuments, Docs2Answers, JoinAnswers, RouteDocuments
|
||||
from haystack.nodes.query_classifier import SklearnQueryClassifier, TransformersQueryClassifier
|
||||
from haystack.nodes.file_classifier import FileTypeClassifier
|
||||
from haystack.utils import preprocessing
|
||||
|
@ -21,7 +21,7 @@ from haystack.nodes.file_converter import (
|
||||
AzureConverter,
|
||||
ParsrConverter,
|
||||
)
|
||||
from haystack.nodes.other import Docs2Answers, JoinDocuments
|
||||
from haystack.nodes.other import Docs2Answers, JoinDocuments, RouteDocuments, JoinAnswers
|
||||
from haystack.nodes.preprocessor import BasePreProcessor, PreProcessor
|
||||
from haystack.nodes.query_classifier import SklearnQueryClassifier, TransformersQueryClassifier
|
||||
from haystack.nodes.question_generator import QuestionGenerator
|
||||
|
@ -1,2 +1,4 @@
|
||||
from haystack.nodes.other.docs2answers import Docs2Answers
|
||||
from haystack.nodes.other.join_docs import JoinDocuments
|
||||
from haystack.nodes.other.route_documents import RouteDocuments
|
||||
from haystack.nodes.other.join_answers import JoinAnswers
|
||||
|
64
haystack/nodes/other/join_answers.py
Normal file
64
haystack/nodes/other/join_answers.py
Normal file
@ -0,0 +1,64 @@
|
||||
from typing import Optional, List, Dict, Tuple
|
||||
|
||||
from haystack.schema import Answer
|
||||
from haystack.nodes import BaseComponent
|
||||
|
||||
|
||||
class JoinAnswers(BaseComponent):
|
||||
"""
|
||||
A node to join `Answer`s produced by multiple `Reader` nodes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, join_mode: str = "concatenate", weights: Optional[List[float]] = None, top_k_join: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
:param join_mode: `"concatenate"` to combine documents from multiple `Reader`s. `"merge"` to aggregate scores
|
||||
of individual `Answer`s.
|
||||
:param weights: A node-wise list (length of list must be equal to the number of input nodes) of weights for
|
||||
adjusting `Answer` scores when using the `"merge"` join_mode. By default, equal weight is assigned to each
|
||||
`Reader` score. This parameter is not compatible with the `"concatenate"` join_mode.
|
||||
:param top_k_join: Limit `Answer`s to top_k based on the resulting scored of the join.
|
||||
"""
|
||||
|
||||
assert join_mode in ["concatenate", "merge"], f"JoinAnswers node does not support '{join_mode}' join_mode."
|
||||
assert not (
|
||||
weights is not None and join_mode == "concatenate"
|
||||
), "Weights are not compatible with 'concatenate' join_mode"
|
||||
|
||||
# Save init parameters to enable export of component config as YAML
|
||||
self.set_config(join_mode=join_mode, weights=weights, top_k_join=top_k_join)
|
||||
|
||||
self.join_mode = join_mode
|
||||
self.weights = [float(i) / sum(weights) for i in weights] if weights else None
|
||||
self.top_k_join = top_k_join
|
||||
|
||||
def run(self, inputs: List[Dict], top_k_join: Optional[int] = None) -> Tuple[Dict, str]: # type: ignore
|
||||
reader_results = [inp["answers"] for inp in inputs]
|
||||
|
||||
if not top_k_join:
|
||||
top_k_join = self.top_k_join
|
||||
|
||||
if self.join_mode == "concatenate":
|
||||
concatenated_answers = [answer for cur_reader_result in reader_results for answer in cur_reader_result]
|
||||
concatenated_answers = sorted(concatenated_answers, reverse=True)[:top_k_join]
|
||||
return {"answers": concatenated_answers, "labels": inputs[0].get("labels", None)}, "output_1"
|
||||
|
||||
elif self.join_mode == "merge":
|
||||
merged_answers = self._merge_answers(reader_results)
|
||||
|
||||
merged_answers = merged_answers[:top_k_join]
|
||||
return {"answers": merged_answers, "labels": inputs[0].get("labels", None)}, "output_1"
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid join_mode: {self.join_mode}")
|
||||
|
||||
def _merge_answers(self, reader_results: List[List[Answer]]) -> List[Answer]:
|
||||
weights = self.weights if self.weights else [1 / len(reader_results)] * len(reader_results)
|
||||
|
||||
for result, weight in zip(reader_results, weights):
|
||||
for answer in result:
|
||||
if isinstance(answer.score, float):
|
||||
answer.score *= weight
|
||||
|
||||
return sorted([answer for cur_reader_result in reader_results for answer in cur_reader_result], reverse=True)
|
72
haystack/nodes/other/route_documents.py
Normal file
72
haystack/nodes/other/route_documents.py
Normal file
@ -0,0 +1,72 @@
|
||||
from typing import List, Tuple, Dict, Optional
|
||||
|
||||
from haystack.nodes.base import BaseComponent
|
||||
from haystack.schema import Document
|
||||
|
||||
|
||||
class RouteDocuments(BaseComponent):
|
||||
"""
|
||||
A node to split a list of `Document`s by `content_type` or by the values of a metadata field and route them to
|
||||
different nodes.
|
||||
"""
|
||||
|
||||
# By default (split_by == "content_type"), the node has two outgoing edges.
|
||||
outgoing_edges = 2
|
||||
|
||||
def __init__(self, split_by: str = "content_type", metadata_values: Optional[List[str]] = None):
|
||||
"""
|
||||
:param split_by: Field to split the documents by, either `"content_type"` or a metadata field name.
|
||||
If this parameter is set to `"content_type"`, the list of `Document`s will be split into a list containing
|
||||
only `Document`s of type `"text"` (will be routed to `"output_1"`) and a list containing only `Document`s of
|
||||
type `"text"` (will be routed to `"output_2"`).
|
||||
If this parameter is set to a metadata field name, you need to specify the parameter `metadata_values` as
|
||||
well.
|
||||
:param metadata_values: If the parameter `split_by` is set to a metadata field name, you need to provide a list
|
||||
of values to group the `Document`s to. `Document`s whose metadata field is equal to the first value of the
|
||||
provided list will be routed to `"output_1"`, `Document`s whose metadata field is equal to the second
|
||||
value of the provided list will be routed to `"output_2"`, etc.
|
||||
"""
|
||||
|
||||
assert split_by == "content_type" or metadata_values is not None, (
|
||||
"If split_by is set to the name of a metadata field, you must provide metadata_values "
|
||||
"to group the documents to."
|
||||
)
|
||||
|
||||
# Save init parameters to enable export of component config as YAML
|
||||
self.set_config(split_by=split_by, metadata_values=metadata_values)
|
||||
|
||||
self.split_by = split_by
|
||||
self.metadata_values = metadata_values
|
||||
|
||||
# If we split list of Documents by a metadata field, number of outgoing edges might change
|
||||
if split_by != "content_type" and metadata_values is not None:
|
||||
self.outgoing_edges = len(metadata_values)
|
||||
|
||||
def run(self, documents: List[Document]) -> Tuple[Dict, str]: # type: ignore
|
||||
if self.split_by == "content_type":
|
||||
split_documents: Dict[str, List[Document]] = {"output_1": [], "output_2": []}
|
||||
|
||||
for doc in documents:
|
||||
if doc.content_type == "text":
|
||||
split_documents["output_1"].append(doc)
|
||||
elif doc.content_type == "table":
|
||||
split_documents["output_2"].append(doc)
|
||||
|
||||
else:
|
||||
assert isinstance(self.metadata_values, list), (
|
||||
"You need to provide metadata_values if you want to split" " a list of Documents by a metadata field."
|
||||
)
|
||||
split_documents = {f"output_{i+1}": [] for i in range(len(self.metadata_values))}
|
||||
for doc in documents:
|
||||
current_metadata_value = doc.meta.get(self.split_by, None)
|
||||
# Disregard current document if it does not contain the provided metadata field
|
||||
if current_metadata_value is not None:
|
||||
try:
|
||||
index = self.metadata_values.index(current_metadata_value)
|
||||
except ValueError:
|
||||
# Disregard current document if current_metadata_value is not in the provided metadata_values
|
||||
continue
|
||||
|
||||
split_documents[f"output_{index+1}"].append(doc)
|
||||
|
||||
return split_documents, "split_documents"
|
@ -645,28 +645,38 @@ class Pipeline(BasePipeline):
|
||||
f"Exception while running node `{node_id}` with input `{node_input}`: {e}, full stack trace: {tb}"
|
||||
)
|
||||
queue.pop(node_id)
|
||||
next_nodes = self.get_next_nodes(node_id, stream_id)
|
||||
for n in next_nodes: # add successor nodes with corresponding inputs to the queue
|
||||
if queue.get(n): # concatenate inputs if it's a join node
|
||||
existing_input = queue[n]
|
||||
if "inputs" not in existing_input.keys():
|
||||
updated_input: dict = {"inputs": [existing_input, node_output], "params": params}
|
||||
if query:
|
||||
updated_input["query"] = query
|
||||
if file_paths:
|
||||
updated_input["file_paths"] = file_paths
|
||||
if labels:
|
||||
updated_input["labels"] = labels
|
||||
if documents:
|
||||
updated_input["documents"] = documents
|
||||
if meta:
|
||||
updated_input["meta"] = meta
|
||||
#
|
||||
if stream_id == "split_documents":
|
||||
for stream_id in [key for key in node_output.keys() if key.startswith("output_")]:
|
||||
current_node_output = {k: v for k, v in node_output.items() if not k.startswith("output_")}
|
||||
current_docs = node_output.pop(stream_id)
|
||||
current_node_output["documents"] = current_docs
|
||||
next_nodes = self.get_next_nodes(node_id, stream_id)
|
||||
for n in next_nodes:
|
||||
queue[n] = current_node_output
|
||||
else:
|
||||
next_nodes = self.get_next_nodes(node_id, stream_id)
|
||||
for n in next_nodes: # add successor nodes with corresponding inputs to the queue
|
||||
if queue.get(n): # concatenate inputs if it's a join node
|
||||
existing_input = queue[n]
|
||||
if "inputs" not in existing_input.keys():
|
||||
updated_input: dict = {"inputs": [existing_input, node_output], "params": params}
|
||||
if query:
|
||||
updated_input["query"] = query
|
||||
if file_paths:
|
||||
updated_input["file_paths"] = file_paths
|
||||
if labels:
|
||||
updated_input["labels"] = labels
|
||||
if documents:
|
||||
updated_input["documents"] = documents
|
||||
if meta:
|
||||
updated_input["meta"] = meta
|
||||
else:
|
||||
existing_input["inputs"].append(node_output)
|
||||
updated_input = existing_input
|
||||
queue[n] = updated_input
|
||||
else:
|
||||
existing_input["inputs"].append(node_output)
|
||||
updated_input = existing_input
|
||||
queue[n] = updated_input
|
||||
else:
|
||||
queue[n] = node_output
|
||||
queue[n] = node_output
|
||||
i = 0
|
||||
else:
|
||||
i += 1 # attempt executing next node in the queue as current `node_id` has unprocessed predecessors
|
||||
|
@ -59,6 +59,9 @@
|
||||
{
|
||||
"$ref": "#/definitions/ImageToTextConverterComponent"
|
||||
},
|
||||
{
|
||||
"$ref": "#/definitions/JoinAnswersComponent"
|
||||
},
|
||||
{
|
||||
"$ref": "#/definitions/JoinDocumentsComponent"
|
||||
},
|
||||
@ -86,6 +89,9 @@
|
||||
{
|
||||
"$ref": "#/definitions/RCIReaderComponent"
|
||||
},
|
||||
{
|
||||
"$ref": "#/definitions/RouteDocumentsComponent"
|
||||
},
|
||||
{
|
||||
"$ref": "#/definitions/SentenceTransformersRankerComponent"
|
||||
},
|
||||
@ -1093,6 +1099,51 @@
|
||||
],
|
||||
"additionalProperties": false
|
||||
},
|
||||
"JoinAnswersComponent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"title": "Name",
|
||||
"description": "Custom name for the component. Helpful for visualization and debugging.",
|
||||
"type": "string"
|
||||
},
|
||||
"type": {
|
||||
"title": "Type",
|
||||
"description": "Haystack Class name for the component.",
|
||||
"type": "string",
|
||||
"const": "JoinAnswers"
|
||||
},
|
||||
"params": {
|
||||
"title": "Parameters",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"join_mode": {
|
||||
"title": "Join Mode",
|
||||
"default": "concatenate",
|
||||
"type": "string"
|
||||
},
|
||||
"weights": {
|
||||
"title": "Weights",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "number"
|
||||
}
|
||||
},
|
||||
"top_k_join": {
|
||||
"title": "Top K Join",
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Each parameter can reference other components defined in the same YAML file."
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"type",
|
||||
"name"
|
||||
],
|
||||
"additionalProperties": false
|
||||
},
|
||||
"JoinDocumentsComponent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@ -1646,6 +1697,47 @@
|
||||
],
|
||||
"additionalProperties": false
|
||||
},
|
||||
"RouteDocumentsComponent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"title": "Name",
|
||||
"description": "Custom name for the component. Helpful for visualization and debugging.",
|
||||
"type": "string"
|
||||
},
|
||||
"type": {
|
||||
"title": "Type",
|
||||
"description": "Haystack Class name for the component.",
|
||||
"type": "string",
|
||||
"const": "RouteDocuments"
|
||||
},
|
||||
"params": {
|
||||
"title": "Parameters",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"split_by": {
|
||||
"title": "Split By",
|
||||
"default": "content_type",
|
||||
"type": "string"
|
||||
},
|
||||
"metadata_values": {
|
||||
"title": "Metadata Values",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Each parameter can reference other components defined in the same YAML file."
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"type",
|
||||
"name"
|
||||
],
|
||||
"additionalProperties": false
|
||||
},
|
||||
"SentenceTransformersRankerComponent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -59,6 +59,9 @@
|
||||
{
|
||||
"$ref": "#/definitions/ImageToTextConverterComponent"
|
||||
},
|
||||
{
|
||||
"$ref": "#/definitions/JoinAnswersComponent"
|
||||
},
|
||||
{
|
||||
"$ref": "#/definitions/JoinDocumentsComponent"
|
||||
},
|
||||
@ -95,6 +98,9 @@
|
||||
{
|
||||
"$ref": "#/definitions/SklearnQueryClassifierComponent"
|
||||
},
|
||||
{
|
||||
"$ref": "#/definitions/SplitDocumentListComponent"
|
||||
},
|
||||
{
|
||||
"$ref": "#/definitions/TableReaderComponent"
|
||||
},
|
||||
@ -1093,6 +1099,51 @@
|
||||
],
|
||||
"additionalProperties": false
|
||||
},
|
||||
"JoinAnswersComponent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"title": "Name",
|
||||
"description": "Custom name for the component. Helpful for visualization and debugging.",
|
||||
"type": "string"
|
||||
},
|
||||
"type": {
|
||||
"title": "Type",
|
||||
"description": "Haystack Class name for the component.",
|
||||
"type": "string",
|
||||
"const": "JoinAnswers"
|
||||
},
|
||||
"params": {
|
||||
"title": "Parameters",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"join_mode": {
|
||||
"title": "Join Mode",
|
||||
"default": "concatenate",
|
||||
"type": "string"
|
||||
},
|
||||
"weights": {
|
||||
"title": "Weights",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "number"
|
||||
}
|
||||
},
|
||||
"top_k_join": {
|
||||
"title": "Top K Join",
|
||||
"type": "integer"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Each parameter can reference other components defined in the same YAML file."
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"type",
|
||||
"name"
|
||||
],
|
||||
"additionalProperties": false
|
||||
},
|
||||
"JoinDocumentsComponent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@ -1836,6 +1887,47 @@
|
||||
],
|
||||
"additionalProperties": false
|
||||
},
|
||||
"SplitDocumentListComponent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"title": "Name",
|
||||
"description": "Custom name for the component. Helpful for visualization and debugging.",
|
||||
"type": "string"
|
||||
},
|
||||
"type": {
|
||||
"title": "Type",
|
||||
"description": "Haystack Class name for the component.",
|
||||
"type": "string",
|
||||
"const": "SplitDocumentList"
|
||||
},
|
||||
"params": {
|
||||
"title": "Parameters",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"split_by": {
|
||||
"title": "Split By",
|
||||
"default": "content_type",
|
||||
"type": "string"
|
||||
},
|
||||
"metadata_values": {
|
||||
"title": "Metadata Values",
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"description": "Each parameter can reference other components defined in the same YAML file."
|
||||
}
|
||||
},
|
||||
"required": [
|
||||
"type",
|
||||
"name"
|
||||
],
|
||||
"additionalProperties": false
|
||||
},
|
||||
"TableReaderComponent": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -3,10 +3,12 @@ from pathlib import Path
|
||||
import os
|
||||
import json
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
import responses
|
||||
|
||||
from haystack import __version__
|
||||
from haystack import __version__, Document, Answer, JoinAnswers
|
||||
from haystack.document_stores.base import BaseDocumentStore
|
||||
from haystack.document_stores.deepsetcloud import DeepsetCloudDocumentStore
|
||||
from haystack.document_stores.elasticsearch import ElasticsearchDocumentStore
|
||||
@ -17,7 +19,7 @@ from haystack.nodes.retriever.base import BaseRetriever
|
||||
from haystack.nodes.retriever.sparse import ElasticsearchRetriever
|
||||
from haystack.pipelines import Pipeline, DocumentSearchPipeline, RootNode, ExtractiveQAPipeline
|
||||
from haystack.pipelines.base import _PipelineCodeGen
|
||||
from haystack.nodes import DensePassageRetriever, EmbeddingRetriever
|
||||
from haystack.nodes import DensePassageRetriever, EmbeddingRetriever, RouteDocuments
|
||||
|
||||
from conftest import MOCK_DC, DC_API_ENDPOINT, DC_API_KEY, DC_TEST_INDEX, SAMPLES_PATH, deepset_cloud_fixture
|
||||
|
||||
@ -1041,6 +1043,51 @@ def test_documentsearch_document_store_authentication(retriever_with_docs, docum
|
||||
assert kwargs["headers"] == auth_headers
|
||||
|
||||
|
||||
def test_route_documents_by_content_type():
|
||||
# Test routing by content_type
|
||||
docs = [
|
||||
Document(content="text document", content_type="text"),
|
||||
Document(
|
||||
content=pd.DataFrame(columns=["col 1", "col 2"], data=[["row 1", "row 1"], ["row 2", "row 2"]]),
|
||||
content_type="table",
|
||||
),
|
||||
]
|
||||
|
||||
route_documents = RouteDocuments()
|
||||
result, _ = route_documents.run(documents=docs)
|
||||
assert len(result["output_1"]) == 1
|
||||
assert len(result["output_2"]) == 1
|
||||
assert result["output_1"][0].content_type == "text"
|
||||
assert result["output_2"][0].content_type == "table"
|
||||
|
||||
|
||||
def test_route_documents_by_metafield(test_docs_xs):
|
||||
# Test routing by metadata field
|
||||
docs = [Document.from_dict(doc) if isinstance(doc, dict) else doc for doc in test_docs_xs]
|
||||
route_documents = RouteDocuments(split_by="meta_field", metadata_values=["test1", "test3", "test5"])
|
||||
result, _ = route_documents.run(docs)
|
||||
assert len(result["output_1"]) == 1
|
||||
assert len(result["output_2"]) == 1
|
||||
assert len(result["output_3"]) == 1
|
||||
assert result["output_1"][0].meta["meta_field"] == "test1"
|
||||
assert result["output_2"][0].meta["meta_field"] == "test3"
|
||||
assert result["output_3"][0].meta["meta_field"] == "test5"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("join_mode", ["concatenate", "merge"])
|
||||
def test_join_answers(join_mode):
|
||||
inputs = [{"answers": [Answer(answer="answer 1", score=0.7)]}, {"answers": [Answer(answer="answer 2", score=0.8)]}]
|
||||
|
||||
join_answers = JoinAnswers(join_mode=join_mode)
|
||||
result, _ = join_answers.run(inputs)
|
||||
assert len(result["answers"]) == 2
|
||||
assert result["answers"] == sorted(result["answers"], reverse=True)
|
||||
|
||||
result, _ = join_answers.run(inputs, top_k_join=1)
|
||||
assert len(result["answers"]) == 1
|
||||
assert result["answers"][0].answer == "answer 2"
|
||||
|
||||
|
||||
def clean_faiss_document_store():
|
||||
if Path("existing_faiss_document_store").exists():
|
||||
os.remove("existing_faiss_document_store")
|
||||
|
File diff suppressed because one or more lines are too long
@ -6,7 +6,7 @@ from haystack.utils import launch_es, fetch_archive_from_http, print_answers
|
||||
from haystack.document_stores import ElasticsearchDocumentStore
|
||||
from haystack import Document, Pipeline
|
||||
from haystack.nodes.retriever import TableTextRetriever
|
||||
from haystack.nodes import TableReader
|
||||
from haystack.nodes import TableReader, FARMReader, RouteDocuments, JoinAnswers
|
||||
|
||||
|
||||
def tutorial15_tableqa():
|
||||
@ -115,6 +115,37 @@ def tutorial15_tableqa():
|
||||
prediction = table_qa_pipeline.run("How many twin buildings are under construction?")
|
||||
print_answers(prediction, details="minimum")
|
||||
|
||||
### Pipeline for QA on Combination of Text and Tables
|
||||
# We are using one node for retrieving both texts and tables, the TableTextRetriever.
|
||||
# In order to do question-answering on the Documents coming from the TableTextRetriever, we need to route
|
||||
# Documents of type "text" to a FARMReader ( or alternatively TransformersReader) and Documents of type
|
||||
# "table" to a TableReader.
|
||||
|
||||
text_reader = FARMReader("deepset/roberta-base-squad2")
|
||||
# In order to get meaningful scores from the TableReader, use "deepset/tapas-large-nq-hn-reader" or
|
||||
# "deepset/tapas-large-nq-reader" as TableReader models. The disadvantage of these models is, however,
|
||||
# that they are not capable of doing aggregations over multiple table cells.
|
||||
table_reader = TableReader("deepset/tapas-large-nq-hn-reader")
|
||||
route_documents = RouteDocuments()
|
||||
join_answers = JoinAnswers()
|
||||
|
||||
text_table_qa_pipeline = Pipeline()
|
||||
text_table_qa_pipeline.add_node(component=retriever, name="TableTextRetriever", inputs=["Query"])
|
||||
text_table_qa_pipeline.add_node(component=route_documents, name="RouteDocuments", inputs=["TableTextRetriever"])
|
||||
text_table_qa_pipeline.add_node(component=text_reader, name="TextReader", inputs=["RouteDocuments.output_1"])
|
||||
text_table_qa_pipeline.add_node(component=table_reader, name="TableReader", inputs=["RouteDocuments.output_2"])
|
||||
text_table_qa_pipeline.add_node(component=join_answers, name="JoinAnswers", inputs=["TextReader", "TableReader"])
|
||||
|
||||
# Example query whose answer resides in a text passage
|
||||
predictions = text_table_qa_pipeline.run(query="Who is Aleksandar Trifunovic?")
|
||||
# We can see both text passages and tables as contexts of the predicted answers.
|
||||
print_answers(predictions, details="minimum")
|
||||
|
||||
# Example query whose answer resides in a table
|
||||
predictions = text_table_qa_pipeline.run(query="What is Cuba's national tree?")
|
||||
# We can see both text passages and tables as contexts of the predicted answers.
|
||||
print_answers(predictions, details="minimum")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tutorial15_tableqa()
|
||||
|
Loading…
x
Reference in New Issue
Block a user