mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-08-27 18:06:17 +00:00
Add support for building custom Search Pipelines (#596)
This commit is contained in:
parent
65cf9547d2
commit
e3a68aedaf
@ -3,6 +3,7 @@ import logging
|
||||
import pandas as pd
|
||||
from haystack.schema import Document, Label, MultiLabel
|
||||
from haystack.finder import Finder
|
||||
from haystack.pipeline import Pipeline
|
||||
|
||||
pd.options.display.max_colwidth = 80
|
||||
|
||||
|
@ -28,6 +28,8 @@ class Finder:
|
||||
:param reader: Reader instance
|
||||
:param retriever: Retriever instance
|
||||
"""
|
||||
logger.warning("The 'Finder' class will be deprecated in the next Haystack release in favour of the new"
|
||||
"`Pipeline` class.")
|
||||
self.retriever = retriever
|
||||
self.reader = reader
|
||||
if self.reader is None and self.retriever is None:
|
||||
@ -478,4 +480,3 @@ class Finder:
|
||||
eval_results["reader_topk_no_answer_accuracy"] = None
|
||||
|
||||
return eval_results
|
||||
|
||||
|
189
haystack/pipeline.py
Normal file
189
haystack/pipeline.py
Normal file
@ -0,0 +1,189 @@
|
||||
import networkx as nx
|
||||
from networkx import DiGraph
|
||||
from networkx.drawing.nx_agraph import to_agraph
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from haystack.reader.base import BaseReader
|
||||
from haystack.retriever.base import BaseRetriever
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""
|
||||
Pipeline brings together building blocks to build a complex search pipeline with Haystack & user-defined components.
|
||||
|
||||
Under-the-hood, a pipeline is represented as a directed acyclic graph of component nodes. It enables custom query
|
||||
flows with options to branch queries(eg, extractive qa vs keyword match query), merge candidate documents for a
|
||||
Reader from multiple Retrievers, or re-ranking of candidate documents.
|
||||
"""
|
||||
def __init__(self):
|
||||
self.graph = DiGraph()
|
||||
self.root_node_id = "Query"
|
||||
self.graph.add_node("Query", component=QueryNode())
|
||||
|
||||
def add_node(self, component, name: str, inputs: List[str]):
|
||||
"""
|
||||
Add a new node to the pipeline.
|
||||
|
||||
:param component: The object to be called when the data is passed to the node. It can be a Haystack component
|
||||
(like Retriever, Reader, or Generator) or a user-defined object that implements a run()
|
||||
method to process incoming data from predecessor node.
|
||||
:param name: The name for the node. It must not contain any dots.
|
||||
:param inputs: A list of inputs to the node. If the predecessor node has a single outgoing edge, just the name
|
||||
of node is sufficient. For instance, a 'ElasticsearchRetriever' node would always output a single
|
||||
edge with a list of documents. It can be represented as ["ElasticsearchRetriever"].
|
||||
|
||||
In cases when the predecessor node has multiple outputs, e.g., a "QueryClassifier", the output
|
||||
must be specified explicitly as "QueryClassifier.output_2".
|
||||
|
||||
|
||||
"""
|
||||
self.graph.add_node(name, component=component)
|
||||
|
||||
for i in inputs:
|
||||
if "." in i:
|
||||
[input_node_name, input_edge_name] = i.split(".")
|
||||
assert "output_" in input_edge_name, f"'{input_edge_name}' is not a valid edge name."
|
||||
outgoing_edges_input_node = self.graph.nodes[input_node_name]["component"].outgoing_edges
|
||||
assert int(input_edge_name.split("_")[1]) <= outgoing_edges_input_node, (
|
||||
f"Cannot connect '{input_edge_name}' from '{input_node_name}' as it only has "
|
||||
f"{outgoing_edges_input_node} outgoing edge(s)."
|
||||
)
|
||||
else:
|
||||
outgoing_edges_input_node = self.graph.nodes[i]["component"].outgoing_edges
|
||||
assert outgoing_edges_input_node == 1, (
|
||||
f"Adding an edge from {i} to {name} is ambiguous as {i} has {outgoing_edges_input_node} edges. "
|
||||
f"Please specify the output explicitly."
|
||||
)
|
||||
input_node_name = i
|
||||
input_edge_name = "output_1"
|
||||
self.graph.add_edge(input_node_name, name, label=input_edge_name)
|
||||
|
||||
def run(self, **kwargs):
|
||||
has_next_node = True
|
||||
current_node_id = self.root_node_id
|
||||
input_dict = kwargs
|
||||
output_dict = None
|
||||
|
||||
while has_next_node:
|
||||
output_dict, stream_id = self.graph.nodes[current_node_id]["component"].run(**input_dict)
|
||||
input_dict = output_dict
|
||||
next_nodes = self._get_next_nodes(current_node_id, stream_id)
|
||||
|
||||
if len(next_nodes) > 1:
|
||||
join_node_id = list(nx.neighbors(self.graph, next_nodes[0]))[0]
|
||||
if set(self.graph.predecessors(join_node_id)) != set(next_nodes):
|
||||
raise NotImplementedError(
|
||||
"The current pipeline does not support multiple levels of parallel nodes."
|
||||
)
|
||||
inputs_for_join_node = {"inputs": []}
|
||||
for n_id in next_nodes:
|
||||
output = self.graph.nodes[n_id]["component"].run(**input_dict)
|
||||
inputs_for_join_node["inputs"].append(output)
|
||||
input_dict = inputs_for_join_node
|
||||
current_node_id = join_node_id
|
||||
elif len(next_nodes) == 1:
|
||||
current_node_id = next_nodes[0]
|
||||
else:
|
||||
has_next_node = False
|
||||
|
||||
return output_dict
|
||||
|
||||
def _get_next_nodes(self, node_id: str, stream_id: str):
|
||||
current_node_edges = self.graph.edges(node_id, data=True)
|
||||
next_nodes = [
|
||||
next_node
|
||||
for _, next_node, data in current_node_edges
|
||||
if not stream_id or data["label"] == stream_id
|
||||
]
|
||||
return next_nodes
|
||||
|
||||
def draw(self, path: Path = Path("pipeline.png")):
|
||||
"""
|
||||
Create a Graphviz visualization of the pipeline.
|
||||
|
||||
:param path: the path to save the image.
|
||||
"""
|
||||
try:
|
||||
import pygraphviz
|
||||
except ImportError:
|
||||
raise ImportError(f"Could not import `pygraphviz`. Please install via: \n"
|
||||
f"pip install pygraphviz\n"
|
||||
f"(You might need to run this first: apt install libgraphviz-dev graphviz )")
|
||||
|
||||
graphviz = to_agraph(self.graph)
|
||||
graphviz.layout("dot")
|
||||
graphviz.draw(path)
|
||||
|
||||
|
||||
class ExtractiveQAPipeline:
|
||||
def __init__(self, reader: BaseReader, retriever: BaseRetriever):
|
||||
"""
|
||||
Initialize a Pipeline for Extractive Question Answering.
|
||||
|
||||
:param reader: Reader instance
|
||||
:param retriever: Retriever instance
|
||||
"""
|
||||
self.pipeline = Pipeline()
|
||||
self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
|
||||
self.pipeline.add_node(component=reader, name="Reader", inputs=["Retriever"])
|
||||
|
||||
def run(self, question, top_k_retriever=5, top_k_reader=5):
|
||||
output = self.pipeline.run(question=question,
|
||||
top_k_retriever=top_k_retriever,
|
||||
top_k_reader=top_k_reader)
|
||||
return output
|
||||
|
||||
def add_node(self, component, name: str, inputs: List[str]):
|
||||
self.pipeline.add_node(component=component, name=name, inputs=inputs)
|
||||
|
||||
def draw(self, path: Path = Path("pipeline.png")):
|
||||
self.pipeline.draw(path)
|
||||
|
||||
|
||||
class DocumentSearchPipeline:
|
||||
def __init__(self, retriever: BaseRetriever):
|
||||
"""
|
||||
Initialize a Pipeline for semantic document search.
|
||||
|
||||
:param retriever: Retriever instance
|
||||
"""
|
||||
self.pipeline = Pipeline()
|
||||
self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
|
||||
|
||||
def run(self, question, top_k_retriever=5):
|
||||
output = self.pipeline.run(question=question, top_k_retriever=top_k_retriever)
|
||||
document_dicts = [doc.to_dict() for doc in output["documents"]]
|
||||
output["documents"] = document_dicts
|
||||
return output
|
||||
|
||||
def add_node(self, component, name: str, inputs: List[str]):
|
||||
self.pipeline.add_node(component=component, name=name, inputs=inputs)
|
||||
|
||||
def draw(self, path: Path = Path("pipeline.png")):
|
||||
self.pipeline.draw(path)
|
||||
|
||||
|
||||
class QueryNode:
|
||||
outgoing_edges = 1
|
||||
|
||||
def run(self, **kwargs):
|
||||
return kwargs, "output_1"
|
||||
|
||||
|
||||
class JoinDocuments:
|
||||
outgoing_edges = 1
|
||||
|
||||
def __init__(self, join_mode="concatenate"):
|
||||
pass
|
||||
|
||||
def run(self, **kwargs):
|
||||
inputs = kwargs["inputs"]
|
||||
|
||||
documents = []
|
||||
for i, _ in inputs:
|
||||
documents.extend(i["documents"])
|
||||
output = {
|
||||
"question": inputs[0][0]["question"],
|
||||
"documents": documents
|
||||
}
|
||||
return output, "output_1"
|
@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
from scipy.special import expit
|
||||
from abc import ABC, abstractmethod
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional, Sequence
|
||||
|
||||
from haystack import Document
|
||||
@ -8,6 +9,7 @@ from haystack import Document
|
||||
|
||||
class BaseReader(ABC):
|
||||
return_no_answers: bool
|
||||
outgoing_edges = 1
|
||||
|
||||
@abstractmethod
|
||||
def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
|
||||
@ -44,3 +46,18 @@ class BaseReader(ABC):
|
||||
"document_id": None,
|
||||
"meta": None,}
|
||||
return no_ans_prediction, max_no_ans_gap
|
||||
|
||||
def run(self, question: str, documents: List[Document], top_k: Optional[int] = None):
|
||||
if documents:
|
||||
results = self.predict(question=question, documents=documents, top_k=top_k)
|
||||
else:
|
||||
results = {"answers": []}
|
||||
|
||||
# Add corresponding document_name and more meta data, if an answer contains the document_id
|
||||
for ans in results["answers"]:
|
||||
ans["meta"] = {}
|
||||
for doc in documents:
|
||||
if doc.id == ans["document_id"]:
|
||||
ans["meta"] = deepcopy(doc.meta)
|
||||
|
||||
return results, "output_1"
|
||||
|
@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
import logging
|
||||
from time import perf_counter
|
||||
from functools import wraps
|
||||
@ -13,6 +13,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class BaseRetriever(ABC):
|
||||
document_store: BaseDocumentStore
|
||||
outgoing_edges = 1
|
||||
|
||||
@abstractmethod
|
||||
def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
|
||||
@ -164,4 +165,23 @@ class BaseRetriever(ABC):
|
||||
if return_preds:
|
||||
return {"metrics": metrics, "predictions": predictions}
|
||||
else:
|
||||
return metrics
|
||||
return metrics
|
||||
|
||||
def run(
|
||||
self,
|
||||
question: str,
|
||||
filters: Optional[dict] = None,
|
||||
top_k_retriever: Optional[int] = None,
|
||||
top_k_reader: Optional[int] = None,
|
||||
):
|
||||
if top_k_retriever:
|
||||
documents = self.retrieve(query=question, filters=filters, top_k=top_k_retriever)
|
||||
else:
|
||||
documents = self.retrieve(query=question, filters=filters)
|
||||
output = {
|
||||
"question": question,
|
||||
"documents": documents,
|
||||
"top_k": top_k_reader
|
||||
}
|
||||
|
||||
return output, "output_1"
|
||||
|
@ -71,6 +71,11 @@ class Document:
|
||||
|
||||
return cls(**_new_doc)
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.to_dict())
|
||||
|
||||
def __str__(self):
|
||||
return str(self.to_dict())
|
||||
|
||||
class Label:
|
||||
def __init__(self, question: str,
|
||||
@ -140,6 +145,11 @@ class Label:
|
||||
str(self.no_answer) +
|
||||
str(self.model_id))
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.to_dict())
|
||||
|
||||
def __str__(self):
|
||||
return str(self.to_dict())
|
||||
|
||||
class MultiLabel:
|
||||
def __init__(self, question: str,
|
||||
@ -181,4 +191,10 @@ class MultiLabel:
|
||||
return cls(**dict)
|
||||
|
||||
def to_dict(self):
|
||||
return self.__dict__
|
||||
return self.__dict__
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.to_dict())
|
||||
|
||||
def __str__(self):
|
||||
return str(self.to_dict())
|
@ -22,3 +22,4 @@ uvloop; sys_platform != 'win32' and sys_platform != 'cygwin'
|
||||
httptools
|
||||
nltk
|
||||
more_itertools
|
||||
networkx
|
62
test/test_pipeline.py
Normal file
62
test/test_pipeline.py
Normal file
@ -0,0 +1,62 @@
|
||||
import pytest
|
||||
|
||||
from haystack.pipeline import ExtractiveQAPipeline, Pipeline
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["elasticsearch"], indirect=True)
|
||||
def test_graph_creation(reader, retriever_with_docs, document_store_with_docs):
|
||||
pipeline = Pipeline()
|
||||
pipeline.add_node(name="ES", component=retriever_with_docs, inputs=["Query"])
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["ES.output_2"])
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["ES.wrong_edge_label"])
|
||||
|
||||
with pytest.raises(Exception):
|
||||
pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["InvalidNode"])
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
|
||||
def test_extractive_qa_answers(reader, retriever_with_docs, document_store_with_docs):
|
||||
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
|
||||
prediction = pipeline.run(question="Who lives in Berlin?", top_k_retriever=10, top_k_reader=3)
|
||||
assert prediction is not None
|
||||
assert prediction["question"] == "Who lives in Berlin?"
|
||||
assert prediction["answers"][0]["answer"] == "Carla"
|
||||
assert prediction["answers"][0]["probability"] <= 1
|
||||
assert prediction["answers"][0]["probability"] >= 0
|
||||
assert prediction["answers"][0]["meta"]["meta_field"] == "test1"
|
||||
assert prediction["answers"][0]["context"] == "My name is Carla and I live in Berlin"
|
||||
|
||||
assert len(prediction["answers"]) == 3
|
||||
|
||||
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
|
||||
def test_extractive_qa_offsets(reader, retriever_with_docs, document_store_with_docs):
|
||||
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
|
||||
prediction = pipeline.run(question="Who lives in Berlin?", top_k_retriever=10, top_k_reader=5)
|
||||
|
||||
assert prediction["answers"][0]["offset_start"] == 11
|
||||
assert prediction["answers"][0]["offset_end"] == 16
|
||||
start = prediction["answers"][0]["offset_start"]
|
||||
end = prediction["answers"][0]["offset_end"]
|
||||
assert prediction["answers"][0]["context"][start:end] == prediction["answers"][0]["answer"]
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.elasticsearch
|
||||
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
|
||||
def test_extractive_qa_answers_single_result(reader, retriever_with_docs, document_store_with_docs):
|
||||
pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
|
||||
query = "testing finder"
|
||||
prediction = pipeline.run(question=query, top_k_retriever=1, top_k_reader=1)
|
||||
assert prediction is not None
|
||||
assert len(prediction["answers"]) == 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user