mirror of https://github.com/deepset-ai/haystack.git
synced 2025-08-28 10:26:27 +00:00
Add support for building custom Search Pipelines (#596)
parent 65cf9547d2 · commit e3a68aedaf
haystack/__init__.py
@@ -3,6 +3,7 @@ import logging
 import pandas as pd
 from haystack.schema import Document, Label, MultiLabel
 from haystack.finder import Finder
+from haystack.pipeline import Pipeline

 pd.options.display.max_colwidth = 80

haystack/finder.py
@@ -28,6 +28,8 @@ class Finder:
         :param reader: Reader instance
         :param retriever: Retriever instance
         """
+        logger.warning("The 'Finder' class will be deprecated in the next Haystack release in favour of the new "
+                       "`Pipeline` class.")
         self.retriever = retriever
         self.reader = reader
         if self.reader is None and self.retriever is None:
@@ -478,4 +480,3 @@ class Finder:
         eval_results["reader_topk_no_answer_accuracy"] = None

         return eval_results
-
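The deprecation warning above points Finder users at the new Pipeline API. As a migration sketch (not part of the commit; `reader` and `retriever` stand in for the same initialized components previously passed to Finder), the equivalent of a Finder is the ExtractiveQAPipeline added in haystack/pipeline.py below:

# Hypothetical migration sketch; `reader` and `retriever` are assumed to be
# already-initialized Haystack components.
from haystack.pipeline import ExtractiveQAPipeline

pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever)
prediction = pipeline.run(question="Who lives in Berlin?", top_k_retriever=10, top_k_reader=3)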
haystack/pipeline.py (new file, 189 lines)
@@ -0,0 +1,189 @@
import networkx as nx
from networkx import DiGraph
from networkx.drawing.nx_agraph import to_agraph
from typing import List
from pathlib import Path

from haystack.reader.base import BaseReader
from haystack.retriever.base import BaseRetriever


class Pipeline:
    """
    Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components.

    Under the hood, a pipeline is represented as a directed acyclic graph of component nodes. It enables custom query
    flows with options to branch queries (e.g., extractive QA vs. keyword match), merge candidate documents for a
    Reader from multiple Retrievers, or re-rank candidate documents.
    """
    def __init__(self):
        self.graph = DiGraph()
        self.root_node_id = "Query"
        self.graph.add_node("Query", component=QueryNode())

    def add_node(self, component, name: str, inputs: List[str]):
        """
        Add a new node to the pipeline.

        :param component: The object to be called when data is passed to the node. It can be a Haystack component
                          (like a Retriever, Reader, or Generator) or a user-defined object that implements a run()
                          method to process incoming data from the predecessor node.
        :param name: The name for the node. It must not contain any dots.
        :param inputs: A list of inputs to the node. If the predecessor node has a single outgoing edge, the name of
                       the node is sufficient. For instance, an 'ElasticsearchRetriever' node always outputs a single
                       edge with a list of documents, so it can be referenced as ["ElasticsearchRetriever"].

                       When the predecessor node has multiple outputs, e.g., a "QueryClassifier", the output must be
                       specified explicitly as "QueryClassifier.output_2".
        """
        self.graph.add_node(name, component=component)

        for i in inputs:
            if "." in i:
                [input_node_name, input_edge_name] = i.split(".")
                assert "output_" in input_edge_name, f"'{input_edge_name}' is not a valid edge name."
                outgoing_edges_input_node = self.graph.nodes[input_node_name]["component"].outgoing_edges
                assert int(input_edge_name.split("_")[1]) <= outgoing_edges_input_node, (
                    f"Cannot connect '{input_edge_name}' from '{input_node_name}' as it only has "
                    f"{outgoing_edges_input_node} outgoing edge(s)."
                )
            else:
                outgoing_edges_input_node = self.graph.nodes[i]["component"].outgoing_edges
                assert outgoing_edges_input_node == 1, (
                    f"Adding an edge from {i} to {name} is ambiguous as {i} has {outgoing_edges_input_node} edges. "
                    f"Please specify the output explicitly."
                )
                input_node_name = i
                input_edge_name = "output_1"
            self.graph.add_edge(input_node_name, name, label=input_edge_name)

    def run(self, **kwargs):
        has_next_node = True
        current_node_id = self.root_node_id
        input_dict = kwargs
        output_dict = None

        while has_next_node:
            output_dict, stream_id = self.graph.nodes[current_node_id]["component"].run(**input_dict)
            input_dict = output_dict
            next_nodes = self._get_next_nodes(current_node_id, stream_id)

            if len(next_nodes) > 1:
                join_node_id = list(nx.neighbors(self.graph, next_nodes[0]))[0]
                if set(self.graph.predecessors(join_node_id)) != set(next_nodes):
                    raise NotImplementedError(
                        "The current pipeline does not support multiple levels of parallel nodes."
                    )
                inputs_for_join_node = {"inputs": []}
                for n_id in next_nodes:
                    output = self.graph.nodes[n_id]["component"].run(**input_dict)
                    inputs_for_join_node["inputs"].append(output)
                input_dict = inputs_for_join_node
                current_node_id = join_node_id
            elif len(next_nodes) == 1:
                current_node_id = next_nodes[0]
            else:
                has_next_node = False

        return output_dict

    def _get_next_nodes(self, node_id: str, stream_id: str):
        current_node_edges = self.graph.edges(node_id, data=True)
        next_nodes = [
            next_node
            for _, next_node, data in current_node_edges
            if not stream_id or data["label"] == stream_id
        ]
        return next_nodes

    def draw(self, path: Path = Path("pipeline.png")):
        """
        Create a Graphviz visualization of the pipeline.

        :param path: the path to save the image.
        """
        try:
            import pygraphviz
        except ImportError:
            raise ImportError(f"Could not import `pygraphviz`. Please install via: \n"
                              f"pip install pygraphviz\n"
                              f"(You might need to run this first: apt install libgraphviz-dev graphviz )")

        graphviz = to_agraph(self.graph)
        graphviz.layout("dot")
        graphviz.draw(path)

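As a usage sketch (not part of the diff), the class above can be wired up node by node; `my_retriever` and `my_reader` are placeholders for any initialized BaseRetriever/BaseReader instances:

from pathlib import Path

# Sketch: build by hand the same linear flow that ExtractiveQAPipeline (below) encapsulates.
pipeline = Pipeline()
# "Query" is the implicit root node created in Pipeline.__init__().
pipeline.add_node(component=my_retriever, name="Retriever", inputs=["Query"])
# A bare node name is enough because a retriever has a single outgoing edge.
pipeline.add_node(component=my_reader, name="Reader", inputs=["Retriever"])
pipeline.draw(Path("my_pipeline.png"))  # requires pygraphviz, see draw() above

# Keyword arguments enter at the root node; each node's output dict becomes
# the keyword arguments of its successor.
prediction = pipeline.run(question="Who lives in Berlin?", top_k_retriever=10, top_k_reader=3)
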
class ExtractiveQAPipeline:
    def __init__(self, reader: BaseReader, retriever: BaseRetriever):
        """
        Initialize a Pipeline for Extractive Question Answering.

        :param reader: Reader instance
        :param retriever: Retriever instance
        """
        self.pipeline = Pipeline()
        self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
        self.pipeline.add_node(component=reader, name="Reader", inputs=["Retriever"])

    def run(self, question, top_k_retriever=5, top_k_reader=5):
        output = self.pipeline.run(question=question,
                                   top_k_retriever=top_k_retriever,
                                   top_k_reader=top_k_reader)
        return output

    def add_node(self, component, name: str, inputs: List[str]):
        self.pipeline.add_node(component=component, name=name, inputs=inputs)

    def draw(self, path: Path = Path("pipeline.png")):
        self.pipeline.draw(path)


class DocumentSearchPipeline:
    def __init__(self, retriever: BaseRetriever):
        """
        Initialize a Pipeline for semantic document search.

        :param retriever: Retriever instance
        """
        self.pipeline = Pipeline()
        self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])

    def run(self, question, top_k_retriever=5):
        output = self.pipeline.run(question=question, top_k_retriever=top_k_retriever)
        document_dicts = [doc.to_dict() for doc in output["documents"]]
        output["documents"] = document_dicts
        return output

    def add_node(self, component, name: str, inputs: List[str]):
        self.pipeline.add_node(component=component, name=name, inputs=inputs)

    def draw(self, path: Path = Path("pipeline.png")):
        self.pipeline.draw(path)


class QueryNode:
    outgoing_edges = 1

    def run(self, **kwargs):
        return kwargs, "output_1"


class JoinDocuments:
    outgoing_edges = 1

    def __init__(self, join_mode="concatenate"):
        pass

    def run(self, **kwargs):
        inputs = kwargs["inputs"]

        documents = []
        for i, _ in inputs:
            documents.extend(i["documents"])
        output = {
            "question": inputs[0][0]["question"],
            "documents": documents
        }
        return output, "output_1"
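JoinDocuments is what makes the branching flow from the Pipeline docstring work end to end. A sketch (not part of the commit) that merges candidates from two retrievers for a single reader; `es_retriever`, `dpr_retriever`, and `reader` are assumed to be initialized components, and the node names are arbitrary:

pipeline = Pipeline()
pipeline.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
pipeline.add_node(component=dpr_retriever, name="DPRRetriever", inputs=["Query"])
# Both branches converge on one join node, which concatenates their documents.
pipeline.add_node(component=JoinDocuments(join_mode="concatenate"),
                  name="JoinResults", inputs=["ESRetriever", "DPRRetriever"])
pipeline.add_node(component=reader, name="Reader", inputs=["JoinResults"])
prediction = pipeline.run(question="Who lives in Berlin?", top_k_retriever=10)

Note that JoinDocuments forwards only "question" and "documents", so the reader falls back to its default top_k in this sketch.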
haystack/reader/base.py
@@ -1,6 +1,7 @@
 import numpy as np
 from scipy.special import expit
 from abc import ABC, abstractmethod
+from copy import deepcopy
 from typing import List, Optional, Sequence

 from haystack import Document
@@ -8,6 +9,7 @@ from haystack import Document

 class BaseReader(ABC):
     return_no_answers: bool
+    outgoing_edges = 1

     @abstractmethod
     def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
@@ -44,3 +46,18 @@ class BaseReader(ABC):
                              "document_id": None,
                              "meta": None,}
         return no_ans_prediction, max_no_ans_gap
+
+    def run(self, question: str, documents: List[Document], top_k: Optional[int] = None):
+        if documents:
+            results = self.predict(question=question, documents=documents, top_k=top_k)
+        else:
+            results = {"answers": []}
+
+        # Add the corresponding document_name and more metadata if an answer contains the document_id
+        for ans in results["answers"]:
+            ans["meta"] = {}
+            for doc in documents:
+                if doc.id == ans["document_id"]:
+                    ans["meta"] = deepcopy(doc.meta)
+
+        return results, "output_1"
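Together with `outgoing_edges = 1`, this run() method is the whole node contract that Pipeline.run() relies on: accept the predecessor's output dict as keyword arguments and return an (output_dict, stream_id) tuple. A hypothetical user-defined node following the same contract:

# Illustrative only: a custom node that drops very short documents before the
# reader sees them. The class name and filtering rule are made up for this sketch.
class ShortDocumentFilter:
    outgoing_edges = 1

    def run(self, **kwargs):
        kwargs["documents"] = [doc for doc in kwargs["documents"] if len(doc.text) > 20]
        return kwargs, "output_1"

It could sit between the two standard nodes via pipeline.add_node(component=ShortDocumentFilter(), name="Filter", inputs=["Retriever"]).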
haystack/retriever/base.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Optional
 import logging
 from time import perf_counter
 from functools import wraps
@@ -13,6 +13,7 @@ logger = logging.getLogger(__name__)

 class BaseRetriever(ABC):
     document_store: BaseDocumentStore
+    outgoing_edges = 1

     @abstractmethod
     def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
@@ -165,3 +166,22 @@ class BaseRetriever(ABC):
             return {"metrics": metrics, "predictions": predictions}
         else:
             return metrics
+
+    def run(
+        self,
+        question: str,
+        filters: Optional[dict] = None,
+        top_k_retriever: Optional[int] = None,
+        top_k_reader: Optional[int] = None,
+    ):
+        if top_k_retriever:
+            documents = self.retrieve(query=question, filters=filters, top_k=top_k_retriever)
+        else:
+            documents = self.retrieve(query=question, filters=filters)
+        output = {
+            "question": question,
+            "documents": documents,
+            "top_k": top_k_reader
+        }
+
+        return output, "output_1"
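The keys of this output dict become the keyword arguments of whatever node comes next, which is why "top_k" (the reader's budget) is threaded through a retriever that does not use it itself. Any node can also be run in isolation; a sketch assuming an initialized BaseRetriever subclass named `retriever`:

# Sketch: exercising the node interface directly.
output, stream_id = retriever.run(question="Who lives in Berlin?", top_k_retriever=10)
assert stream_id == "output_1"
print(output["question"])        # the original query string
print(len(output["documents"]))  # retrieved Document objects
print(output["top_k"])           # None here; set via top_k_reader for a downstream reader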
haystack/schema.py
@@ -71,6 +71,11 @@ class Document:

         return cls(**_new_doc)

+    def __repr__(self):
+        return str(self.to_dict())
+
+    def __str__(self):
+        return str(self.to_dict())
+

 class Label:
     def __init__(self, question: str,
@@ -140,6 +145,11 @@ class Label:
                 str(self.no_answer) +
                 str(self.model_id))

+    def __repr__(self):
+        return str(self.to_dict())
+
+    def __str__(self):
+        return str(self.to_dict())
+

 class MultiLabel:
     def __init__(self, question: str,
@@ -182,3 +192,9 @@ class MultiLabel:

     def to_dict(self):
         return self.__dict__
+
+    def __repr__(self):
+        return str(self.to_dict())
+
+    def __str__(self):
+        return str(self.to_dict())
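With __repr__ and __str__ both delegating to to_dict(), printing a Document, Label, or MultiLabel now shows its full dict form, which is handy when debugging pipeline outputs. A small sketch:

from haystack.schema import Document

doc = Document(text="My name is Carla and I live in Berlin")
print(doc)  # e.g. {'text': 'My name is Carla and I live in Berlin', 'id': ..., ...}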
requirements.txt
@@ -22,3 +22,4 @@ uvloop; sys_platform != 'win32' and sys_platform != 'cygwin'
 httptools
 nltk
 more_itertools
+networkx
test/test_pipeline.py (new file, 62 lines)
@@ -0,0 +1,62 @@
import pytest

from haystack.pipeline import ExtractiveQAPipeline, Pipeline


@pytest.mark.slow
@pytest.mark.elasticsearch
@pytest.mark.parametrize("retriever_with_docs", ["elasticsearch"], indirect=True)
def test_graph_creation(reader, retriever_with_docs, document_store_with_docs):
    pipeline = Pipeline()
    pipeline.add_node(name="ES", component=retriever_with_docs, inputs=["Query"])

    with pytest.raises(AssertionError):
        pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["ES.output_2"])

    with pytest.raises(AssertionError):
        pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["ES.wrong_edge_label"])

    with pytest.raises(Exception):
        pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["InvalidNode"])


@pytest.mark.slow
@pytest.mark.elasticsearch
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
def test_extractive_qa_answers(reader, retriever_with_docs, document_store_with_docs):
    pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
    prediction = pipeline.run(question="Who lives in Berlin?", top_k_retriever=10, top_k_reader=3)
    assert prediction is not None
    assert prediction["question"] == "Who lives in Berlin?"
    assert prediction["answers"][0]["answer"] == "Carla"
    assert prediction["answers"][0]["probability"] <= 1
    assert prediction["answers"][0]["probability"] >= 0
    assert prediction["answers"][0]["meta"]["meta_field"] == "test1"
    assert prediction["answers"][0]["context"] == "My name is Carla and I live in Berlin"

    assert len(prediction["answers"]) == 3


@pytest.mark.elasticsearch
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
def test_extractive_qa_offsets(reader, retriever_with_docs, document_store_with_docs):
    pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
    prediction = pipeline.run(question="Who lives in Berlin?", top_k_retriever=10, top_k_reader=5)

    assert prediction["answers"][0]["offset_start"] == 11
    assert prediction["answers"][0]["offset_end"] == 16
    start = prediction["answers"][0]["offset_start"]
    end = prediction["answers"][0]["offset_end"]
    assert prediction["answers"][0]["context"][start:end] == prediction["answers"][0]["answer"]


@pytest.mark.slow
@pytest.mark.elasticsearch
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
def test_extractive_qa_answers_single_result(reader, retriever_with_docs, document_store_with_docs):
    pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
    query = "testing finder"
    prediction = pipeline.run(question=query, top_k_retriever=1, top_k_reader=1)
    assert prediction is not None
    assert len(prediction["answers"]) == 1