From e3a68aedaff093a44b43358e77ab9a5290ebb7d6 Mon Sep 17 00:00:00 2001
From: Tanay Soni
Date: Fri, 20 Nov 2020 17:41:08 +0100
Subject: [PATCH] Add support for building custom Search Pipelines (#596)

---
 haystack/__init__.py       |   1 +
 haystack/finder.py         |   3 +-
 haystack/pipeline.py       | 189 +++++++++++++++++++++++++++++++++++++
 haystack/reader/base.py    |  17 ++++
 haystack/retriever/base.py |  24 ++++-
 haystack/schema.py         |  18 +++-
 requirements.txt           |   1 +
 test/test_pipeline.py      |  62 ++++++++++++
 8 files changed, 311 insertions(+), 4 deletions(-)
 create mode 100644 haystack/pipeline.py
 create mode 100644 test/test_pipeline.py

diff --git a/haystack/__init__.py b/haystack/__init__.py
index 4c7042f28..0ca42f5e7 100644
--- a/haystack/__init__.py
+++ b/haystack/__init__.py
@@ -3,6 +3,7 @@ import logging
 import pandas as pd
 from haystack.schema import Document, Label, MultiLabel
 from haystack.finder import Finder
+from haystack.pipeline import Pipeline
 
 pd.options.display.max_colwidth = 80
diff --git a/haystack/finder.py b/haystack/finder.py
index 49f0232dc..31a62ca9e 100644
--- a/haystack/finder.py
+++ b/haystack/finder.py
@@ -28,6 +28,8 @@ class Finder:
         :param reader: Reader instance
         :param retriever: Retriever instance
         """
+        logger.warning("The 'Finder' class will be deprecated in the next Haystack release in favour of the new "
+                       "`Pipeline` class.")
         self.retriever = retriever
         self.reader = reader
         if self.reader is None and self.retriever is None:
@@ -478,4 +480,3 @@ class Finder:
             eval_results["reader_topk_no_answer_accuracy"] = None
 
         return eval_results
-
diff --git a/haystack/pipeline.py b/haystack/pipeline.py
new file mode 100644
index 000000000..0e68a2b71
--- /dev/null
+++ b/haystack/pipeline.py
@@ -0,0 +1,189 @@
+import networkx as nx
+from networkx import DiGraph
+from networkx.drawing.nx_agraph import to_agraph
+from typing import List
+from pathlib import Path
+from haystack.reader.base import BaseReader
+from haystack.retriever.base import BaseRetriever
+
+
+class Pipeline:
+    """
+    Pipeline brings together building blocks to build a complex search pipeline with Haystack and user-defined
+    components.
+
+    Under the hood, a pipeline is represented as a directed acyclic graph of component nodes. It enables custom query
+    flows with options to branch queries (e.g., extractive QA vs. keyword match), merge candidate documents for a
+    Reader from multiple Retrievers, or re-rank candidate documents.
+    """
+    def __init__(self):
+        self.graph = DiGraph()
+        self.root_node_id = "Query"
+        self.graph.add_node("Query", component=QueryNode())
+
+    def add_node(self, component, name: str, inputs: List[str]):
+        """
+        Add a new node to the pipeline.
+
+        :param component: The object to be called when data is passed to the node. It can be a Haystack component
+                          (like a Retriever, Reader, or Generator) or a user-defined object that implements a run()
+                          method to process incoming data from its predecessor node.
+        :param name: The name for the node. It must not contain any dots.
+        :param inputs: A list of inputs to the node. If the predecessor node has a single outgoing edge, just the name
+                       of the node is sufficient. For instance, an 'ElasticsearchRetriever' node always outputs a
+                       single edge with a list of documents. It can be represented as ["ElasticsearchRetriever"].
+
+                       In cases when the predecessor node has multiple outputs, e.g., a "QueryClassifier", the output
+                       must be specified explicitly as "QueryClassifier.output_2".
+
+
+        """
+        self.graph.add_node(name, component=component)
+
+        for i in inputs:
+            if "." in i:
+                [input_node_name, input_edge_name] = i.split(".")
+                assert "output_" in input_edge_name, f"'{input_edge_name}' is not a valid edge name."
+                outgoing_edges_input_node = self.graph.nodes[input_node_name]["component"].outgoing_edges
+                assert int(input_edge_name.split("_")[1]) <= outgoing_edges_input_node, (
+                    f"Cannot connect '{input_edge_name}' from '{input_node_name}' as it only has "
+                    f"{outgoing_edges_input_node} outgoing edge(s)."
+                )
+            else:
+                outgoing_edges_input_node = self.graph.nodes[i]["component"].outgoing_edges
+                assert outgoing_edges_input_node == 1, (
+                    f"Adding an edge from {i} to {name} is ambiguous as {i} has {outgoing_edges_input_node} edges. "
+                    f"Please specify the output explicitly."
+                )
+                input_node_name = i
+                input_edge_name = "output_1"
+            self.graph.add_edge(input_node_name, name, label=input_edge_name)
+
+    def run(self, **kwargs):
+        has_next_node = True
+        current_node_id = self.root_node_id
+        input_dict = kwargs
+        output_dict = None
+
+        while has_next_node:
+            output_dict, stream_id = self.graph.nodes[current_node_id]["component"].run(**input_dict)
+            input_dict = output_dict
+            next_nodes = self._get_next_nodes(current_node_id, stream_id)
+
+            if len(next_nodes) > 1:
+                join_node_id = list(nx.neighbors(self.graph, next_nodes[0]))[0]
+                if set(self.graph.predecessors(join_node_id)) != set(next_nodes):
+                    raise NotImplementedError(
+                        "The current pipeline does not support multiple levels of parallel nodes."
+                    )
+                inputs_for_join_node = {"inputs": []}
+                for n_id in next_nodes:
+                    output = self.graph.nodes[n_id]["component"].run(**input_dict)
+                    inputs_for_join_node["inputs"].append(output)
+                input_dict = inputs_for_join_node
+                current_node_id = join_node_id
+            elif len(next_nodes) == 1:
+                current_node_id = next_nodes[0]
+            else:
+                has_next_node = False
+
+        return output_dict
+
+    def _get_next_nodes(self, node_id: str, stream_id: str):
+        current_node_edges = self.graph.edges(node_id, data=True)
+        next_nodes = [
+            next_node
+            for _, next_node, data in current_node_edges
+            if not stream_id or data["label"] == stream_id
+        ]
+        return next_nodes
+
+    def draw(self, path: Path = Path("pipeline.png")):
+        """
+        Create a Graphviz visualization of the pipeline.
+
+        :param path: the path to save the image.
+        """
+        try:
+            import pygraphviz
+        except ImportError:
+            raise ImportError(f"Could not import `pygraphviz`. Please install via: \n"
+                              f"pip install pygraphviz\n"
+                              f"(You might need to run this first: apt install libgraphviz-dev graphviz)")
+
+        graphviz = to_agraph(self.graph)
+        graphviz.layout("dot")
+        graphviz.draw(path)
+
+
+class ExtractiveQAPipeline:
+    def __init__(self, reader: BaseReader, retriever: BaseRetriever):
+        """
+        Initialize a Pipeline for Extractive Question Answering.
+
+        :param reader: Reader instance
+        :param retriever: Retriever instance
+        """
+        self.pipeline = Pipeline()
+        self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
+        self.pipeline.add_node(component=reader, name="Reader", inputs=["Retriever"])
+
+    def run(self, question, top_k_retriever=5, top_k_reader=5):
+        output = self.pipeline.run(question=question,
+                                   top_k_retriever=top_k_retriever,
+                                   top_k_reader=top_k_reader)
+        return output
+
+    def add_node(self, component, name: str, inputs: List[str]):
+        self.pipeline.add_node(component=component, name=name, inputs=inputs)
+
+    def draw(self, path: Path = Path("pipeline.png")):
+        self.pipeline.draw(path)
+
+
+class DocumentSearchPipeline:
+    def __init__(self, retriever: BaseRetriever):
+        """
+        Initialize a Pipeline for semantic document search.
+
+        :param retriever: Retriever instance
+        """
+        self.pipeline = Pipeline()
+        self.pipeline.add_node(component=retriever, name="Retriever", inputs=["Query"])
+
+    def run(self, question, top_k_retriever=5):
+        output = self.pipeline.run(question=question, top_k_retriever=top_k_retriever)
+        document_dicts = [doc.to_dict() for doc in output["documents"]]
+        output["documents"] = document_dicts
+        return output
+
+    def add_node(self, component, name: str, inputs: List[str]):
+        self.pipeline.add_node(component=component, name=name, inputs=inputs)
+
+    def draw(self, path: Path = Path("pipeline.png")):
+        self.pipeline.draw(path)
+
+
+class QueryNode:
+    outgoing_edges = 1
+
+    def run(self, **kwargs):
+        return kwargs, "output_1"
+
+
+class JoinDocuments:
+    outgoing_edges = 1
+
+    def __init__(self, join_mode="concatenate"):
+        pass
+
+    def run(self, **kwargs):
+        inputs = kwargs["inputs"]
+
+        documents = []
+        for i, _ in inputs:
+            documents.extend(i["documents"])
+        output = {
+            "question": inputs[0][0]["question"],
+            "documents": documents
+        }
+        return output, "output_1"
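Usage sketch (illustrative only, not part of the patch): the building blocks above let a query branch into two retrievers whose candidate documents are merged by a JoinDocuments node before reaching the Reader. The `es_retriever`, `dpr_retriever`, and `reader` instances are assumed to be initialized elsewhere:

    from haystack.pipeline import Pipeline, JoinDocuments

    p = Pipeline()
    # Both retrievers hang off the implicit "Query" root node.
    p.add_node(component=es_retriever, name="ESRetriever", inputs=["Query"])
    p.add_node(component=dpr_retriever, name="DPRRetriever", inputs=["Query"])
    # JoinDocuments concatenates the candidate documents of its predecessors.
    # Note that it forwards only "question" and "documents" to the next node.
    p.add_node(component=JoinDocuments(join_mode="concatenate"), name="JoinResults",
               inputs=["ESRetriever", "DPRRetriever"])
    p.add_node(component=reader, name="Reader", inputs=["JoinResults"])

    res = p.run(question="Who lives in Berlin?", top_k_retriever=10)
    p.draw()  # writes pipeline.png; requires pygraphviz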
diff --git a/haystack/reader/base.py b/haystack/reader/base.py
index de914e22c..1b2d54382 100644
--- a/haystack/reader/base.py
+++ b/haystack/reader/base.py
@@ -1,6 +1,7 @@
 import numpy as np
 from scipy.special import expit
 from abc import ABC, abstractmethod
+from copy import deepcopy
 from typing import List, Optional, Sequence
 
 from haystack import Document
@@ -8,6 +9,7 @@ from haystack import Document
 
 class BaseReader(ABC):
     return_no_answers: bool
+    outgoing_edges = 1
 
     @abstractmethod
     def predict(self, question: str, documents: List[Document], top_k: Optional[int] = None):
@@ -44,3 +46,18 @@ class BaseReader(ABC):
                              "document_id": None,
                              "meta": None,}
         return no_ans_prediction, max_no_ans_gap
+
+    def run(self, question: str, documents: List[Document], top_k: Optional[int] = None):
+        if documents:
+            results = self.predict(question=question, documents=documents, top_k=top_k)
+        else:
+            results = {"answers": []}
+
+        # If an answer carries a document_id, attach the matching document's meta data to it
+        for ans in results["answers"]:
+            ans["meta"] = {}
+            for doc in documents:
+                if doc.id == ans["document_id"]:
+                    ans["meta"] = deepcopy(doc.meta)
+
+        return results, "output_1"
diff --git a/haystack/retriever/base.py b/haystack/retriever/base.py
index 9936ee17c..3b87b2795 100644
--- a/haystack/retriever/base.py
+++ b/haystack/retriever/base.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import List
+from typing import List, Optional
 import logging
 from time import perf_counter
 from functools import wraps
@@ -13,6 +13,7 @@ logger = logging.getLogger(__name__)
 
 class BaseRetriever(ABC):
     document_store: BaseDocumentStore
+    outgoing_edges = 1
 
     @abstractmethod
     def retrieve(self, query: str, filters: dict = None, top_k: int = 10, index: str = None) -> List[Document]:
@@ -164,4 +165,23 @@ class BaseRetriever(ABC):
         if return_preds:
             return {"metrics": metrics, "predictions": predictions}
         else:
-            return metrics
\ No newline at end of file
+            return metrics
+
+    def run(
+        self,
+        question: str,
+        filters: Optional[dict] = None,
+        top_k_retriever: Optional[int] = None,
+        top_k_reader: Optional[int] = None,
+    ):
+        if top_k_retriever:
+            documents = self.retrieve(query=question, filters=filters, top_k=top_k_retriever)
+        else:
+            documents = self.retrieve(query=question, filters=filters)
+        output = {
+            "question": question,
+            "documents": documents,
+            "top_k": top_k_reader
+        }
+
+        return output, "output_1"
diff --git a/haystack/schema.py b/haystack/schema.py
index a70e1e23b..9f89caca5 100644
--- a/haystack/schema.py
+++ b/haystack/schema.py
@@ -71,6 +71,11 @@ class Document:
 
         return cls(**_new_doc)
 
+    def __repr__(self):
+        return str(self.to_dict())
+
+    def __str__(self):
+        return str(self.to_dict())
 
 class Label:
     def __init__(self, question: str,
@@ -140,6 +145,11 @@ class Label:
                     str(self.no_answer) +
                     str(self.model_id))
 
+    def __repr__(self):
+        return str(self.to_dict())
+
+    def __str__(self):
+        return str(self.to_dict())
 
 class MultiLabel:
     def __init__(self, question: str,
@@ -181,4 +191,10 @@ class MultiLabel:
         return cls(**dict)
 
     def to_dict(self):
-        return self.__dict__
\ No newline at end of file
+        return self.__dict__
+
+    def __repr__(self):
+        return str(self.to_dict())
+
+    def __str__(self):
+        return str(self.to_dict())
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 8f2449277..a279ab8a3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,3 +22,4 @@ uvloop; sys_platform != 'win32' and sys_platform != 'cygwin'
 httptools
 nltk
 more_itertools
+networkx
\ No newline at end of file
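The run() methods added to BaseReader and BaseRetriever above also define the contract for user-defined nodes: an `outgoing_edges` class attribute plus a run() method that returns a tuple of (output dict, stream id). A minimal sketch of a hypothetical custom node (not part of the patch; it assumes `Document.text` holds the document content, and the 50-character threshold is an arbitrary example value):

    class DocumentLengthFilter:
        outgoing_edges = 1

        def __init__(self, min_chars=50):
            self.min_chars = min_chars

        def run(self, **kwargs):
            # Drop short candidate documents; pass everything else through unchanged.
            kwargs["documents"] = [doc for doc in kwargs["documents"]
                                   if len(doc.text) >= self.min_chars]
            return kwargs, "output_1"

Such a node could then be wired in between a Retriever and a Reader with, e.g., pipeline.add_node(component=DocumentLengthFilter(), name="LenFilter", inputs=["Retriever"]).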
diff --git a/test/test_pipeline.py b/test/test_pipeline.py
new file mode 100644
index 000000000..9e0aa1c3d
--- /dev/null
+++ b/test/test_pipeline.py
@@ -0,0 +1,62 @@
+import pytest
+
+from haystack.pipeline import ExtractiveQAPipeline, Pipeline
+
+
+@pytest.mark.slow
+@pytest.mark.elasticsearch
+@pytest.mark.parametrize("retriever_with_docs", ["elasticsearch"], indirect=True)
+def test_graph_creation(reader, retriever_with_docs, document_store_with_docs):
+    pipeline = Pipeline()
+    pipeline.add_node(name="ES", component=retriever_with_docs, inputs=["Query"])
+
+    with pytest.raises(AssertionError):
+        pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["ES.output_2"])
+
+    with pytest.raises(AssertionError):
+        pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["ES.wrong_edge_label"])
+
+    with pytest.raises(Exception):
+        pipeline.add_node(name="Reader", component=retriever_with_docs, inputs=["InvalidNode"])
+
+
+@pytest.mark.slow
+@pytest.mark.elasticsearch
+@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
+def test_extractive_qa_answers(reader, retriever_with_docs, document_store_with_docs):
+    pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
+    prediction = pipeline.run(question="Who lives in Berlin?", top_k_retriever=10, top_k_reader=3)
+    assert prediction is not None
+    assert prediction["question"] == "Who lives in Berlin?"
+
+    assert prediction["answers"][0]["answer"] == "Carla"
+    assert prediction["answers"][0]["probability"] <= 1
+    assert prediction["answers"][0]["probability"] >= 0
+    assert prediction["answers"][0]["meta"]["meta_field"] == "test1"
+    assert prediction["answers"][0]["context"] == "My name is Carla and I live in Berlin"
+
+    assert len(prediction["answers"]) == 3
+
+
+@pytest.mark.elasticsearch
+@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
+def test_extractive_qa_offsets(reader, retriever_with_docs, document_store_with_docs):
+    pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
+    prediction = pipeline.run(question="Who lives in Berlin?", top_k_retriever=10, top_k_reader=5)
+
+    assert prediction["answers"][0]["offset_start"] == 11
+    assert prediction["answers"][0]["offset_end"] == 16
+    start = prediction["answers"][0]["offset_start"]
+    end = prediction["answers"][0]["offset_end"]
+    assert prediction["answers"][0]["context"][start:end] == prediction["answers"][0]["answer"]
+
+
+@pytest.mark.slow
+@pytest.mark.elasticsearch
+@pytest.mark.parametrize("retriever_with_docs", ["tfidf"], indirect=True)
+def test_extractive_qa_answers_single_result(reader, retriever_with_docs, document_store_with_docs):
+    pipeline = ExtractiveQAPipeline(reader=reader, retriever=retriever_with_docs)
+    query = "testing finder"
+    prediction = pipeline.run(question=query, top_k_retriever=1, top_k_reader=1)
+    assert prediction is not None
+    assert len(prediction["answers"]) == 1
+
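For reference, the prebuilt ExtractiveQAPipeline and DocumentSearchPipeline exercised by these tests replace typical Finder usage. A minimal usage sketch, assuming `reader` and `retriever` are already initialized components (as the test fixtures provide):

    from haystack.pipeline import ExtractiveQAPipeline, DocumentSearchPipeline

    qa = ExtractiveQAPipeline(reader=reader, retriever=retriever)
    prediction = qa.run(question="Who lives in Berlin?", top_k_retriever=10, top_k_reader=3)

    search = DocumentSearchPipeline(retriever=retriever)
    result = search.run(question="Who lives in Berlin?", top_k_retriever=10)
    docs = result["documents"]  # documents serialized via Document.to_dict()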