Add support for building custom Search Pipelines (#596)

Tanay Soni 2020-11-20 17:41:08 +01:00 committed by GitHub
parent 65cf9547d2
commit e3a68aedaf
8 changed files with 311 additions and 4 deletions

haystack/__init__.py
View File

@@ -3,6 +3,7 @@ import logging
import pandas as pd
from haystack.schema import Document, Label, MultiLabel
from haystack.finder import Finder
from haystack.pipeline import Pipeline
pd.options.display.max_colwidth = 80

haystack/finder.py
View File

@@ -28,6 +28,8 @@ class Finder:
        :param reader: Reader instance
        :param retriever: Retriever instance
        """
        logger.warning("The 'Finder' class will be deprecated in the next Haystack release in favour of the new "
                       "`Pipeline` class.")
        self.retriever = retriever
        self.reader = reader
        if self.reader is None and self.retriever is None:
@@ -478,4 +480,3 @@ class Finder:
            eval_results["reader_topk_no_answer_accuracy"] = None
        return eval_results

189
haystack/pipeline.py Normal file
View File

@@ -0,0 +1,189 @@
import networkx as nx
from networkx import DiGraph
from networkx.drawing.nx_agraph import to_agraph
from typing import List
from pathlib import Path

from haystack.reader.base import BaseReader
from haystack.retriever.base import BaseRetriever


class Pipeline:
    """
    Pipeline brings together building blocks to build a complex search pipeline with Haystack and user-defined
    components. Under the hood, a pipeline is represented as a directed acyclic graph of component nodes. It enables
    custom query flows with options to branch queries (e.g., extractive QA vs. keyword match), merge candidate
    documents for a Reader from multiple Retrievers, and re-rank candidate documents.
    """

    def __init__(self):
        self.graph = DiGraph()
        self.root_node_id = "Query"
        self.graph.add_node("Query", component=QueryNode())

    def add_node(self, component, name: str, inputs: List[str]):
        """
        Add a new node to the pipeline.

        :param component: The object to be called when data is passed to the node. It can be a Haystack component
                          (like a Retriever, Reader, or Generator) or a user-defined object that implements a run()
                          method to process incoming data from its predecessor node.
        :param name: The name for the node. It must not contain any dots.
        :param inputs: A list of inputs to the node. If a predecessor node has a single outgoing edge, just the name
                       of that node is sufficient. For instance, an 'ElasticsearchRetriever' always outputs a single
                       edge with a list of documents, so it can be referenced simply as ["ElasticsearchRetriever"].
                       When a predecessor node has multiple outputs, e.g., a "QueryClassifier", the output edge must
                       be specified explicitly, as in "QueryClassifier.output_2".
        """
        self.graph.add_node(name, component=component)

        for i in inputs:
            if "." in i:
                input_node_name, input_edge_name = i.split(".")
                assert "output_" in input_edge_name, f"'{input_edge_name}' is not a valid edge name."
                outgoing_edges_input_node = self.graph.nodes[input_node_name]["component"].outgoing_edges
                assert int(input_edge_name.split("_")[1]) <= outgoing_edges_input_node, (
                    f"Cannot connect '{input_edge_name}' from '{input_node_name}' as it only has "
                    f"{outgoing_edges_input_node} outgoing edge(s)."
                )
            else:
                outgoing_edges_input_node = self.graph.nodes[i]["component"].outgoing_edges
                assert outgoing_edges_input_node == 1, (
                    f"Adding an edge from {i} to {name} is ambiguous as {i} has {outgoing_edges_input_node} edges. "
                    f"Please specify the output explicitly."
                )
                input_node_name = i
                input_edge_name = "output_1"
            self.graph.add_edge(input_node_name, name, label=input_edge_name)

    def run(self, **kwargs):
        has_next_node = True
        current_node_id = self.root_node_id
        input_dict = kwargs
        output_dict = None

        while has_next_node:
            output_dict, stream_id = self.graph.nodes[current_node_id]["component"].run(**input_dict)
            input_dict = output_dict
            next_nodes = self._get_next_nodes(current_node_id, stream_id)

            if len(next_nodes) > 1:
                # Parallel branches must converge directly in a single join node (e.g., JoinDocuments);
                # nested parallel sub-graphs are not supported yet.
                join_node_id = list(nx.neighbors(self.graph, next_nodes[0]))[0]
                if set(self.graph.predecessors(join_node_id)) != set(next_nodes):
                    raise NotImplementedError(
                        "The current pipeline does not support multiple levels of parallel nodes."
                    )
                inputs_for_join_node = {"inputs": []}
                for n_id in next_nodes:
                    output = self.graph.nodes[n_id]["component"].run(**input_dict)
                    inputs_for_join_node["inputs"].append(output)
                input_dict = inputs_for_join_node
                current_node_id = join_node_id
            elif len(next_nodes) == 1:
                current_node_id = next_nodes[0]
            else:
                has_next_node = False

        return output_dict

    def _get_next_nodes(self, node_id: str, stream_id: str):
        current_node_edges = self.graph.edges(node_id, data=True)
        next_nodes = [
            next_node
            for _, next_node, data in current_node_edges
            if not stream_id or data["label"] == stream_id
        ]
        return next_nodes

    def draw(self, path: Path = Path("pipeline.png")):
        """
        Create a Graphviz visualization of the pipeline.

        :param path: the path to save the image.
        """
        try:
            import pygraphviz
        except ImportError:
            raise ImportError("Could not import `pygraphviz`. Please install via: \n"
                              "pip install pygraphviz\n"
                              "(You might need to run this first: apt install libgraphviz-dev graphviz)")

        graphviz = to_agraph(self.graph)
        graphviz.layout("dot")
        graphviz.draw(path)
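
# A quick way to inspect a pipeline's structure (the file name is illustrative):
#
#   pipeline.draw(path=Path("custom_pipeline.png"))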


class ExtractiveQAPipeline:
    def __init__(self, reader: BaseReader, retriever: BaseRetriever):
        """
        Initialize a Pipeline for Extractive Question Answering.

        :param reader: Reader instance
        :param retriever: Retriever instance
        """
        self.pipeline = Pipeline()
        self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
        self.pipeline.add_node(component=reader, name="Reader", inputs=["Retriever"])

    def run(self, question, top_k_retriever=5, top_k_reader=5):
        output = self.pipeline.run(question=question,
                                   top_k_retriever=top_k_retriever,
                                   top_k_reader=top_k_reader)
        return output

    def add_node(self, component, name: str, inputs: List[str]):
        self.pipeline.add_node(component=component, name=name, inputs=inputs)

    def draw(self, path: Path = Path("pipeline.png")):
        self.pipeline.draw(path)
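
# Usage sketch (mirrors test_pipeline.py; 'reader' and 'retriever' are assumed instances):
#
#   pipe = ExtractiveQAPipeline(reader=reader, retriever=retriever)
#   prediction = pipe.run(question="Who lives in Berlin?", top_k_retriever=10, top_k_reader=3)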


class DocumentSearchPipeline:
    def __init__(self, retriever: BaseRetriever):
        """
        Initialize a Pipeline for semantic document search.

        :param retriever: Retriever instance
        """
        self.pipeline = Pipeline()
        self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])

    def run(self, question, top_k_retriever=5):
        output = self.pipeline.run(question=question, top_k_retriever=top_k_retriever)
        document_dicts = [doc.to_dict() for doc in output["documents"]]
        output["documents"] = document_dicts
        return output

    def add_node(self, component, name: str, inputs: List[str]):
        self.pipeline.add_node(component=component, name=name, inputs=inputs)

    def draw(self, path: Path = Path("pipeline.png")):
        self.pipeline.draw(path)
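
# Usage sketch (any BaseRetriever instance works here):
#
#   search = DocumentSearchPipeline(retriever=retriever)
#   res = search.run(question="Who lives in Berlin?", top_k_retriever=10)
#   # res["documents"] is a list of plain dicts, one per retrieved document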


class QueryNode:
    outgoing_edges = 1

    def run(self, **kwargs):
        # The root node simply forwards the query and any parameters to its successors.
        return kwargs, "output_1"


class JoinDocuments:
    outgoing_edges = 1

    def __init__(self, join_mode="concatenate"):
        # Only "concatenate" is implemented so far; the parameter is kept for future join strategies.
        assert join_mode == "concatenate", f"JoinDocuments does not support join_mode='{join_mode}' yet."
        self.join_mode = join_mode

    def run(self, **kwargs):
        inputs = kwargs["inputs"]  # a list of (output_dict, stream_id) tuples from the parallel predecessors
        documents = []
        for input_dict, _ in inputs:
            documents.extend(input_dict["documents"])
        output = {
            "question": inputs[0][0]["question"],
            "documents": documents,
            # carry over "top_k" (if present) so a downstream Reader still receives top_k_reader
            "top_k": inputs[0][0].get("top_k"),
        }
        return output, "output_1"

haystack/reader/base.py
View File

@@ -1,6 +1,7 @@
import numpy as np
from scipy.special import expit
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import List, Optional, Sequence
from haystack import Document
@@ -8,6 +9,7 @@ from haystack import Document
class BaseReader(ABC):
    return_no_answers: bool
    outgoing_edges = 1

    @abstractmethod
    def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
@@ -44,3 +46,18 @@ class BaseReader(ABC):
                             "document_id": None,
                             "meta": None,}
        return no_ans_prediction, max_no_ans_gap

    def run(self, question: str, documents: List[Document], top_k: Optional[int] = None):
        if documents:
            results = self.predict(question=question, documents=documents, top_k=top_k)
        else:
            results = {"answers": []}

        # Add the corresponding document's metadata to each answer that carries a document_id
        for ans in results["answers"]:
            ans["meta"] = {}
            for doc in documents:
                if doc.id == ans["document_id"]:
                    ans["meta"] = deepcopy(doc.meta)

        return results, "output_1"

haystack/retriever/base.py
View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List
from typing import List, Optional
import logging
from time import perf_counter
from functools import wraps
@@ -13,6 +13,7 @@ logger = logging.getLogger(__name__)
class BaseRetriever(ABC):
    document_store: BaseDocumentStore
    outgoing_edges = 1

    @abstractmethod
    def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
@@ -164,4 +165,23 @@ class BaseRetriever(ABC):
        if return_preds:
            return {"metrics": metrics, "predictions": predictions}
        else:
            return metrics

    def run(
        self,
        question: str,
        filters: Optional[dict] = None,
        top_k_retriever: Optional[int] = None,
        top_k_reader: Optional[int] = None,
    ):
        if top_k_retriever:
            documents = self.retrieve(query=question, filters=filters, top_k=top_k_retriever)
        else:
            documents = self.retrieve(query=question, filters=filters)

        output = {
            "question": question,
            "documents": documents,
            # pass top_k_reader through as "top_k" for a downstream Reader node
            "top_k": top_k_reader,
        }
        return output, "output_1"

haystack/schema.py
View File

@@ -71,6 +71,11 @@ class Document:
        return cls(**_new_doc)

    def __repr__(self):
        return str(self.to_dict())

    def __str__(self):
        return str(self.to_dict())

class Label:
    def __init__(self, question: str,
@@ -140,6 +145,11 @@ class Label:
                    str(self.no_answer) +
                    str(self.model_id))

    def __repr__(self):
        return str(self.to_dict())

    def __str__(self):
        return str(self.to_dict())

class MultiLabel:
    def __init__(self, question: str,
@@ -181,4 +191,10 @@ class MultiLabel:
        return cls(**dict)

    def to_dict(self):
        return self.__dict__

    def __repr__(self):
        return str(self.to_dict())

    def __str__(self):
        return str(self.to_dict())

requirements.txt
View File

@@ -22,3 +22,4 @@ uvloop; sys_platform != 'win32' and sys_platform != 'cygwin'
httptools
nltk
more_itertools
networkx

62
test/test_pipeline.py Normal file
View File

@@ -0,0 +1,62 @@
import pytest

from haystack.pipeline import ExtractiveQAPipeline, Pipeline


@pytest.mark.slow
@pytest.mark.elasticsearch
@pytest.mark.parametrize("retriever_with_docs", ["elasticsearch"], indirect=True)
def test_graph_creation(reader, retriever_with_docs, document_store_with_docs):
    pipeline = Pipeline()
    pipeline.add_node(name="ES", component=retriever_with_docs, inputs=["Query"])

    with pytest.raises(AssertionError):
        pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["ES.output_2"])

    with pytest.raises(AssertionError):
        pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["ES.wrong_edge_label"])

    with pytest.raises(Exception):
        pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["InvalidNode"])


@pytest.mark.slow
@pytest.mark.elasticsearch
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
def test_extractive_qa_answers(reader, retriever_with_docs, document_store_with_docs):
    pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
    prediction = pipeline.run(question="Who lives in Berlin?", top_k_retriever=10, top_k_reader=3)
    assert prediction is not None
    assert prediction["question"] == "Who lives in Berlin?"
    assert prediction["answers"][0]["answer"] == "Carla"
    assert prediction["answers"][0]["probability"] <= 1
    assert prediction["answers"][0]["probability"] >= 0
    assert prediction["answers"][0]["meta"]["meta_field"] == "test1"
    assert prediction["answers"][0]["context"] == "My name is Carla and I live in Berlin"
    assert len(prediction["answers"]) == 3


@pytest.mark.elasticsearch
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
def test_extractive_qa_offsets(reader, retriever_with_docs, document_store_with_docs):
    pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
    prediction = pipeline.run(question="Who lives in Berlin?", top_k_retriever=10, top_k_reader=5)
    assert prediction["answers"][0]["offset_start"] == 11
    assert prediction["answers"][0]["offset_end"] == 16

    start = prediction["answers"][0]["offset_start"]
    end = prediction["answers"][0]["offset_end"]
    assert prediction["answers"][0]["context"][start:end] == prediction["answers"][0]["answer"]


@pytest.mark.slow
@pytest.mark.elasticsearch
@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
def test_extractive_qa_answers_single_result(reader, retriever_with_docs, document_store_with_docs):
    pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
    query = "testing finder"
    prediction = pipeline.run(question=query, top_k_retriever=1, top_k_reader=1)
    assert prediction is not None
    assert len(prediction["answers"]) == 1