diff --git a/.github/labeler.yml b/.github/labeler.yml
index 39565aa7b..8434604da 100644
--- a/.github/labeler.yml
+++ b/.github/labeler.yml
@@ -31,6 +31,9 @@ topic:reader:
 topic:retriever:
 - haystack/nodes/retriever/*
 - test/nodes/test_retriever.py
+topic:agent:
+- haystack/agents/*
+- test/agents/*
 topic:pipeline:
 - haystack/pipelines/*
 - haystack/nodes/other/*
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 772b6b84e..542afabdb 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -411,6 +411,7 @@ jobs:
         folder:
           - "nodes"
           - "pipelines"
+          - "agents"
           - "modeling"
           - "others"
@@ -455,6 +456,7 @@ jobs:
         folder:
           - "nodes"
           - "pipelines"
+          - "agents"
           - "modeling"
           #- "others"
@@ -504,6 +506,7 @@ jobs:
         folder:
           - "nodes"
           - "pipelines"
+          - "agents"
           - "modeling"
           - "others"
@@ -599,6 +602,7 @@ jobs:
         folder:
           - "nodes"
           - "pipelines"
+          - "agents"
           - "modeling"
           - "others"
diff --git a/haystack/agents/__init__.py b/haystack/agents/__init__.py
new file mode 100644
index 000000000..d3e7b91e6
--- /dev/null
+++ b/haystack/agents/__init__.py
@@ -0,0 +1,2 @@
+from haystack.agents.base import Agent
+from haystack.agents.base import Tool
diff --git a/haystack/agents/base.py b/haystack/agents/base.py
new file mode 100644
index 000000000..aaef4accb
--- /dev/null
+++ b/haystack/agents/base.py
@@ -0,0 +1,290 @@
+from __future__ import annotations
+
+import logging
+import re
+from typing import List, Optional, Union, Dict, Tuple, Any
+
+from haystack import Pipeline, BaseComponent, Answer
+from haystack.errors import AgentError
+from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
+from haystack.pipelines import (
+    BaseStandardPipeline,
+    ExtractiveQAPipeline,
+    DocumentSearchPipeline,
+    GenerativeQAPipeline,
+    SearchSummarizationPipeline,
+    FAQPipeline,
+    TranslationWrapperPipeline,
+    RetrieverQuestionGenerationPipeline,
+)
+from haystack.telemetry import send_custom_event
+
+logger = logging.getLogger(__name__)
+
+
+class Tool:
+    """
+    A tool is a pipeline or node that also has a name and description. When you add a Tool to an Agent, the Agent can
+    invoke the underlying pipeline or node to answer questions. The wording of the description is important because
+    the Agent uses it to decide which tool is most useful for a given question.
+
+    :param name: The name of the tool. The Agent uses this name to refer to the tool in the text it generates.
+        The name should be short, ideally one token, and a good description of what the tool can do, for example
+        "Calculator" or "Search".
+    :param pipeline_or_node: The pipeline or node to run when an Agent invokes this tool.
+    :param description: A description of what the tool is useful for. The Agent uses it to decide when to use which
+        tool. For example, a tool for calculations can be described as "useful for when you need to answer questions
+        about math".
+    """
+
+    def __init__(
+        self,
+        name: str,
+        pipeline_or_node: Union[
+            BaseComponent,
+            Pipeline,
+            ExtractiveQAPipeline,
+            DocumentSearchPipeline,
+            GenerativeQAPipeline,
+            SearchSummarizationPipeline,
+            FAQPipeline,
+            TranslationWrapperPipeline,
+            RetrieverQuestionGenerationPipeline,
+        ],
+        description: str,
+    ):
+        self.name = name
+        self.pipeline_or_node = pipeline_or_node
+        self.description = description
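+
+
+# Illustrative sketch (an editor's example, not part of the tested API): wrapping an
+# existing pipeline as a Tool. `search_pipeline` is a hypothetical, previously built
+# ExtractiveQAPipeline.
+#
+#     search_tool = Tool(
+#         name="Search",
+#         pipeline_or_node=search_pipeline,
+#         description="useful for when you need to answer questions about where people live",
+#     )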
+
+
+class Agent:
+    """
+    An Agent answers queries by choosing between different tools, which are pipelines or nodes. It uses a large
+    language model (LLM) to generate a thought based on the query, choose a tool, and generate the input for the
+    tool. Based on the result returned by the tool, the Agent either stops if it now knows the answer, or repeats
+    the process of 1) generating a thought, 2) choosing a tool, and 3) generating the tool's input.
+
+    Agents are useful for questions containing multiple subquestions that can be answered step-by-step (multi-hop QA)
+    using multiple pipelines and nodes as tools.
+    """
+
+    def __init__(
+        self,
+        prompt_node: PromptNode,
+        prompt_template: Union[str, PromptTemplate] = "zero-shot-react",
+        tools: Optional[List[Tool]] = None,
+        max_iterations: int = 5,
+        tool_pattern: str = r'Tool:\s*(\w+)\s*Tool Input:\s*("?)([^"\n]+)\2\s*',
+        final_answer_pattern: str = r"Final Answer:\s*(\w+)\s*",
+    ):
+        """
+        Creates an Agent instance.
+
+        :param prompt_node: The PromptNode that the Agent uses to decide in each iteration which tool to use and what input to provide to it.
+        :param prompt_template: The name of a PromptTemplate supported by the PromptNode or a new PromptTemplate. The Agent uses it to generate thoughts and run Tools to answer queries step-by-step.
+        :param tools: A list of Tools that the Agent can choose to run. If you don't provide any Tools, add them with `add_tool()` before using the Agent.
+        :param max_iterations: The maximum number of iterations the Agent runs. In each iteration, the Agent either runs a tool or infers the final answer,
+            so set this to at least 2 to allow for one tool run followed by the final answer. Defaults to 5.
+        :param tool_pattern: A regular expression to extract the name of the Tool and the corresponding input from the text the Agent generates.
+        :param final_answer_pattern: A regular expression to extract the final answer from the text the Agent generates.
+        """
+        self.prompt_node = prompt_node
+        self.prompt_template = (
+            prompt_node.get_prompt_template(prompt_template) if isinstance(prompt_template, str) else prompt_template
+        )
+        self.tools = {tool.name: tool for tool in tools} if tools else {}
+        self.tool_names = ", ".join(self.tools.keys())
+        self.tool_names_with_descriptions = "\n".join(
+            [f"{tool.name}: {tool.description}" for tool in self.tools.values()]
+        )
+        self.max_iterations = max_iterations
+        self.tool_pattern = tool_pattern
+        self.final_answer_pattern = final_answer_pattern
+        send_custom_event(event=f"{type(self).__name__} initialized")
+
+    def add_tool(self, tool: Tool):
+        """
+        Add the provided tool to the Agent and update the template for the Agent's PromptNode.
+
+        :param tool: The Tool to add to the Agent. Adding a tool with the same name as an existing one overwrites the existing tool.
+        """
+        self.tools[tool.name] = tool
+        self.tool_names = ", ".join(self.tools.keys())
+        self.tool_names_with_descriptions = "\n".join(
+            [f"{tool.name}: {tool.description}" for tool in self.tools.values()]
+        )
+
+    def has_tool(self, tool_name: str):
+        """
+        Check whether the Agent has a Tool registered under the provided name.
+
+        :param tool_name: The name of the Tool to check for.
+        """
+        return tool_name in self.tools
+
+    def run(
+        self, query: str, max_iterations: Optional[int] = None, params: Optional[dict] = None
+    ) -> Dict[str, Union[str, List[Answer]]]:
+        """
+        Runs the Agent given a query and optional parameters to pass on to the tools used. The result is in the
+        same format as a pipeline's result: a dictionary with a key `answers` containing a list of Answers.
+
+        :param query: The search query.
+        :param max_iterations: The maximum number of iterations for this run. If set, it must be at least 2 so that
+            the Agent can run one tool and then infer the final answer. Defaults to the value set at initialization.
+        :param params: A dictionary of parameters that you want to pass to those tools that are pipelines.
+                       To pass a parameter to all nodes in those pipelines, use: `{"top_k": 10}`.
+                       To pass a parameter to targeted nodes in those pipelines, use:
+                       `{"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}}`.
+                       You can only pass parameters to tools that are pipelines, not to tools that are nodes.
+        """
+        if not self.tools:
+            raise AgentError(
+                "Agents without tools cannot be run. Add at least one tool using `add_tool()` or set the parameter `tools` when initializing the Agent."
+            )
+        if max_iterations is None:
+            max_iterations = self.max_iterations
+        if max_iterations < 2:
+            raise AgentError(
+                f"max_iterations was set to {max_iterations} but it should be at least 2 so that the Agent can run one tool and then infer it knows the final answer."
+            )
+
+        transcript = self._get_initial_transcript(query=query)
+        # Generate a thought with a plan of what to do, choose a tool, generate its input, and run it.
+        # Repeat this until the final answer is found or the maximum number of iterations is reached.
+        for _ in range(max_iterations):
+            preds = self.prompt_node(transcript)
+            if not preds:
+                raise AgentError(f"No output was generated by the Agent. Transcript:\n{transcript}")
+
+            # Try to extract the final answer or the tool name and input from the generated text and update the transcript
+            final_answer = self._extract_final_answer(pred=preds[0])
+            if final_answer is not None:
+                transcript += preds[0]
+                return self._format_answer(query=query, transcript=transcript, answer=final_answer)
+            tool_name, tool_input = self._extract_tool_name_and_tool_input(pred=preds[0])
+            if tool_name is None or tool_input is None:
+                raise AgentError(f"Wrong output format. Transcript:\n{transcript}")
+
+            result = self._run_tool(tool_name=tool_name, tool_input=tool_input, transcript=transcript, params=params)
+            observation = self._extract_observation(result)
+            transcript += f"{preds[0]}\nObservation: {observation}\nThought: Now that I know that {observation} is the answer to {tool_name} {tool_input}, I "
+
+        logger.warning(
+            "Maximum number of iterations (%s) reached for query (%s). Increase max_iterations; otherwise, "
+            "no answer can be provided for this query.",
+            max_iterations,
+            query,
+        )
+        return self._format_answer(query=query, transcript=transcript, answer="")
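+
+    # An illustrative sketch of calling `run` (editor's example; the node name
+    # "Retriever" is hypothetical). Parameters reach only tools that are pipelines:
+    #
+    #     result = agent.run(
+    #         query="Where was the composer of Carmen born?",
+    #         params={"Retriever": {"top_k": 3}},
+    #     )
+    #     print(result["answers"][0].answer)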
+
+    def run_batch(
+        self, queries: List[str], max_iterations: Optional[int] = None, params: Optional[dict] = None
+    ) -> Dict[str, List[Any]]:
+        """
+        Runs the Agent in batch mode.
+
+        :param queries: List of search queries.
+        :param max_iterations: The maximum number of iterations for each query. If set, it must be at least 2 so that
+            the Agent can run one tool and then infer the final answer.
+        :param params: A dictionary of parameters that you want to pass to those tools that are pipelines.
+                       To pass a parameter to all nodes in those pipelines, use: `{"top_k": 10}`.
+                       To pass a parameter to targeted nodes in those pipelines, use:
+                       `{"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}}`.
+                       You can only pass parameters to tools that are pipelines, not to tools that are nodes.
+        """
+        results: Dict = {"queries": [], "answers": [], "transcripts": []}
+        for query in queries:
+            result = self.run(query=query, max_iterations=max_iterations, params=params)
+            results["queries"].append(result["query"])
+            results["answers"].append(result["answers"])
+            results["transcripts"].append(result["transcript"])
+
+        return results
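+
+    # The returned dictionary aggregates the per-query results, roughly (sketch):
+    #
+    #     {
+    #         "queries": ["first query", "second query"],
+    #         "answers": [[Answer(...)], [Answer(...)]],
+    #         "transcripts": ["...", "..."],
+    #     }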
+
+    def _run_tool(
+        self, tool_name: str, tool_input: str, transcript: str, params: Optional[dict] = None
+    ) -> Union[Tuple[Dict[str, Any], str], Dict[str, Any]]:
+        if not self.has_tool(tool_name):
+            raise AgentError(
+                f"Cannot use the tool {tool_name} because it is not in the list of added tools {self.tools.keys()}. "
+                "Add the tool using `add_tool()` or include it in the parameter `tools` when initializing the Agent. "
+                f"Transcript:\n{transcript}"
+            )
+
+        pipeline_or_node = self.tools[tool_name].pipeline_or_node
+        # We can only pass params to pipelines but not to nodes
+        if isinstance(pipeline_or_node, (Pipeline, BaseStandardPipeline)):
+            result = pipeline_or_node.run(query=tool_input, params=params)
+        elif isinstance(pipeline_or_node, BaseRetriever):
+            result = pipeline_or_node.run(query=tool_input, root_node="Query")
+        else:
+            result = pipeline_or_node.run(query=tool_input)
+        return result
+
+    def _extract_observation(self, result: Union[Tuple[Dict[str, Any], str], Dict[str, Any]]) -> str:
+        observation = ""
+        # If the result was returned by a node, it is a tuple; we use only the output, not the name of the output edge.
+        # If the result was returned by a pipeline, it is a dict that we can use directly.
+        if isinstance(result, tuple):
+            result = result[0]
+        if isinstance(result, dict):
+            if result.get("results", None):
+                observation = result["results"][0]
+            elif result.get("answers", None):
+                observation = result["answers"][0].answer
+            elif result.get("documents", None):
+                observation = result["documents"][0].content
+
+        # observation remains "" if no result/answer/document was returned
+        return observation
+
+    def _extract_tool_name_and_tool_input(self, pred: str) -> Tuple[Optional[str], Optional[str]]:
+        """
+        Parse the tool name and the tool input from the prediction output of the Agent's PromptNode.
+
+        :param pred: Prediction output of the Agent's PromptNode from which to parse the tool and tool input.
+        """
+        tool_match = re.search(self.tool_pattern, pred)
+        if tool_match:
+            tool_name = tool_match.group(1)
+            tool_input = tool_match.group(3)
+            return tool_name.strip('" []').strip(), tool_input.strip('" ')
+        return None, None
+
+    def _extract_final_answer(self, pred: str) -> Optional[str]:
+        """
+        Parse the final answer from the prediction output of the Agent's PromptNode.
+
+        :param pred: Prediction output of the Agent's PromptNode from which to parse the final answer.
+        """
+        final_answer_match = re.search(self.final_answer_pattern, pred)
+        if final_answer_match:
+            final_answer = final_answer_match.group(1)
+            return final_answer.strip('" ')
+        return None
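+
+    # What the default patterns extract (mirrors the unit tests below):
+    #
+    #     'Tool: Search\nTool Input: "Where was Jeremy McKinnon born"\n'
+    #         -> tool_name == "Search", tool_input == "Where was Jeremy McKinnon born"
+    #     "Final Answer: Florida"
+    #         -> final_answer == "Florida"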
+
+    def _format_answer(self, query: str, answer: str, transcript: str) -> Dict[str, Union[str, List[Answer]]]:
+        """
+        Formats an answer as a dict containing `query` and `answers`, similar to the output of a Pipeline.
+        It also contains the full transcript: the Agent's initial prompt template filled with the query, plus the
+        text the Agent generated during execution.
+
+        :param query: The search query.
+        :param answer: The final answer returned by the Agent. An empty string corresponds to no answer.
+        :param transcript: The initial filled template plus the text generated by the Agent, for debugging purposes.
+        """
+        return {"query": query, "answers": [Answer(answer=answer, type="generative")], "transcript": transcript}
+
+    def _get_initial_transcript(self, query: str):
+        """
+        Fills the Agent's PromptTemplate with the query, tool names, and tool descriptions.
+
+        :param query: The search query.
+        """
+        return next(
+            self.prompt_template.fill(
+                query=query, tool_names=self.tool_names, tool_names_with_descriptions=self.tool_names_with_descriptions
+            ),
+            "",
+        )
diff --git a/haystack/errors.py b/haystack/errors.py
index 59d134888..ab17d3fd7 100644
--- a/haystack/errors.py
+++ b/haystack/errors.py
@@ -47,6 +47,15 @@ class ModelingError(HaystackError):
         super().__init__(message=message, docs_link=docs_link)
 
 
+class AgentError(HaystackError):
+    """Exception for issues raised within an agent"""
+
+    def __init__(
+        self, message: Optional[str] = None, docs_link: Optional[str] = "https://docs.haystack.deepset.ai/docs/agents"
+    ):
+        super().__init__(message=message, docs_link=docs_link)
+
+
 class PipelineError(HaystackError):
     """Exception for issues raised within a pipeline"""
 
diff --git a/haystack/nodes/prompt/prompt_node.py b/haystack/nodes/prompt/prompt_node.py
index 01e2d2d11..84852630e 100644
--- a/haystack/nodes/prompt/prompt_node.py
+++ b/haystack/nodes/prompt/prompt_node.py
@@ -689,6 +689,27 @@ def get_predefined_prompt_templates() -> List[PromptTemplate]:
             name="translation",
             prompt_text="Translate the following context to $target_language. Context: $documents; Translation:",
         ),
+        PromptTemplate(
+            name="zero-shot-react",
+            prompt_text="You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
+            "correctly, you have access to the following tools:\n\n"
+            "$tool_names_with_descriptions\n\n"
+            "To answer questions, you'll need to go through multiple steps involving step-by-step thinking and "
+            "selecting appropriate tools and their inputs; tools will respond with observations. When you are ready "
+            "for a final answer, respond with the `Final Answer:`\n\n"
+            "Use the following format:\n\n"
+            "Question: the question to be answered\n"
+            "Thought: Reason if you have the final answer. If yes, answer the question. If not, find out the missing information needed to answer it.\n"
+            "Tool: [$tool_names]\n"
+            "Tool Input: the input for the tool\n"
+            "Observation: the tool will respond with the result\n"
+            "...\n"
+            "Final Answer: the final answer to the question, make it short (1-5 words)\n\n"
+            "Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n"
+            "---\n\n"
+            "Question: $query\n"
+            "Thought: Let's think step-by-step, I first need to ",
+        ),
     ]
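
For reference, a transcript generated with this template looks roughly like the following (a hypothetical, abridged
run with two tools named Search and Count; the Agent itself appends each Observation and the beginning of the next
Thought):

    Question: How many letters does the name of the town where Christelle lives have?
    Thought: Let's think step-by-step, I first need to find out where Christelle lives.
    Tool: Search
    Tool Input: Where does Christelle live?
    Observation: Madrid
    Thought: Now that I know that Madrid is the answer to Search Where does Christelle live?, I need to count the letters in Madrid.
    Tool: Count
    Tool Input: Madrid
    Observation: 6
    Thought: Now that I know that 6 is the answer to Count Madrid, I have the final answer.
    Final Answer: 6
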
diff --git a/test/agents/__init__.py b/test/agents/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/test/agents/test_agent.py b/test/agents/test_agent.py
new file mode 100644
index 000000000..878d29acf
--- /dev/null
+++ b/test/agents/test_agent.py
@@ -0,0 +1,251 @@
+import logging
+import os
+from typing import Tuple
+
+import pytest
+
+from haystack import BaseComponent, Answer
+from haystack.agents import Agent
+from haystack.agents.base import Tool
+from haystack.errors import AgentError
+from haystack.nodes import PromptModel, PromptNode, PromptTemplate
+from haystack.pipelines import ExtractiveQAPipeline, DocumentSearchPipeline, BaseStandardPipeline
+from test.conftest import MockRetriever, MockPromptNode
+
+
+def test_add_and_overwrite_tool():
+    # Add a Node as a Tool to an Agent
+    agent = Agent(prompt_node=MockPromptNode())
+    retriever = MockRetriever()
+    agent.add_tool(
+        Tool(
+            name="Retriever",
+            pipeline_or_node=retriever,
+            description="useful for when you need to retrieve documents from your index",
+        )
+    )
+    assert len(agent.tools) == 1
+    assert "Retriever" in agent.tools
+    assert agent.has_tool(tool_name="Retriever")
+    assert isinstance(agent.tools["Retriever"].pipeline_or_node, BaseComponent)
+
+    agent.add_tool(
+        Tool(
+            name="Retriever",
+            pipeline_or_node=retriever,
+            description="useful for when you need to retrieve documents from your index",
+        )
+    )
+
+    # Add a Pipeline as a Tool to an Agent and overwrite the previously added Tool
+    retriever_pipeline = DocumentSearchPipeline(MockRetriever())
+    agent.add_tool(
+        Tool(
+            name="Retriever",
+            pipeline_or_node=retriever_pipeline,
+            description="useful for when you need to retrieve documents from your index",
+        )
+    )
+    assert len(agent.tools) == 1
+    assert "Retriever" in agent.tools
+    assert agent.has_tool(tool_name="Retriever")
+    assert isinstance(agent.tools["Retriever"].pipeline_or_node, BaseStandardPipeline)
+
+
+def test_agent_chooses_no_action():
+    agent = Agent(prompt_node=MockPromptNode())
+    retriever = MockRetriever()
+    agent.add_tool(
+        Tool(
+            name="Retriever",
+            pipeline_or_node=retriever,
+            description="useful for when you need to retrieve documents from your index",
+        )
+    )
+    with pytest.raises(AgentError, match=r"Wrong output format.*"):
+        agent.run("How many letters does the name of the town where Christelle lives have?")
+
+
+def test_max_iterations(caplog, monkeypatch):
+    # Run an Agent and stop because max_iterations is reached
+    agent = Agent(prompt_node=MockPromptNode(), max_iterations=3)
+    retriever = MockRetriever()
+    agent.add_tool(
+        Tool(
+            name="Retriever",
+            pipeline_or_node=retriever,
+            description="useful for when you need to retrieve documents from your index",
+        )
+    )
+
+    # Let the Agent always choose "Retriever" as the Tool with "" as input
+    def mock_extract_tool_name_and_tool_input(self, pred: str) -> Tuple[str, str]:
+        return "Retriever", ""
+
+    monkeypatch.setattr(Agent, "_extract_tool_name_and_tool_input", mock_extract_tool_name_and_tool_input)
+
+    # Using max_iterations as specified in the Agent's init method
+    with caplog.at_level(logging.WARN, logger="haystack.agents"):
+        result = agent.run("Where does Christelle live?")
+        assert result["answers"] == [Answer(answer="", type="generative")]
+        assert "Maximum number of iterations (3) reached" in caplog.text
+
+    # Setting max_iterations in the Agent's run method
+    with caplog.at_level(logging.WARN, logger="haystack.agents"):
+        result = agent.run("Where does Christelle live?", max_iterations=2)
+        assert result["answers"] == [Answer(answer="", type="generative")]
+        assert "Maximum number of iterations (2) reached" in caplog.text
+
+
+def test_run_tool():
+    agent = Agent(prompt_node=MockPromptNode())
+    retriever = MockRetriever()
+    agent.add_tool(
+        Tool(
+            name="Retriever",
+            pipeline_or_node=retriever,
+            description="useful for when you need to retrieve documents from your index",
+        )
+    )
+    result = agent._run_tool(tool_name="Retriever", tool_input="", transcript="")
+    assert result[0]["documents"] == []
+
+
+def test_extract_observation():
+    agent = Agent(prompt_node=MockPromptNode())
+    observation = agent._extract_observation(
+        result={
+            "answers": [
+                Answer(answer="first answer", type="generative"),
+                Answer(answer="second answer", type="generative"),
+            ]
+        }
+    )
+    assert observation == "first answer"
+
+
+def test_extract_tool_name_and_tool_input():
+    agent = Agent(prompt_node=MockPromptNode())
+
+    pred = "need to find out what city he was born.\nTool: Search\nTool Input: Where was Jeremy McKinnon born"
+    tool_name, tool_input = agent._extract_tool_name_and_tool_input(pred)
+    assert tool_name == "Search" and tool_input == "Where was Jeremy McKinnon born"
+
+
+def test_extract_final_answer():
+    agent = Agent(prompt_node=MockPromptNode())
+
+    pred = "have the final answer to the question.\nFinal Answer: Florida"
+    final_answer = agent._extract_final_answer(pred)
+    assert final_answer == "Florida"
+
+
+def test_format_answer():
+    agent = Agent(prompt_node=MockPromptNode())
+    formatted_answer = agent._format_answer(query="query", answer="answer", transcript="transcript")
+    assert formatted_answer["query"] == "query"
+    assert formatted_answer["answers"] == [Answer(answer="answer", type="generative")]
+    assert formatted_answer["transcript"] == "transcript"
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("reader", ["farm"], indirect=True)
+@pytest.mark.parametrize("retriever_with_docs, document_store_with_docs", [("bm25", "memory")], indirect=True)
+@pytest.mark.skipif(
+    not os.environ.get("OPENAI_API_KEY", None),
+    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+)
+def test_agent_run(reader, retriever_with_docs, document_store_with_docs):
+    search = ExtractiveQAPipeline(reader, retriever_with_docs)
+    prompt_model = PromptModel(model_name_or_path="text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"))
+    prompt_node = PromptNode(model_name_or_path=prompt_model, stop_words=["Observation:"])
+    counter = PromptNode(
+        model_name_or_path=prompt_model,
+        default_prompt_template=PromptTemplate(
+            name="calculator_template",
+            prompt_text="When I give you a word, respond with the number of characters that this word contains.\n"
+            "Word: Rome\nLength: 4\n"
+            "Word: Arles\nLength: 5\n"
+            "Word: Berlin\nLength: 6\n"
+            "Word: $query?\nLength: ",
+            prompt_params=["query"],
+        ),
+    )
+
+    agent = Agent(prompt_node=prompt_node)
+    agent.add_tool(
+        Tool(
+            name="Search",
+            pipeline_or_node=search,
+            description="useful for when you need to answer "
+            "questions about where people live. You "
+            "should ask targeted questions",
+        )
+    )
+    agent.add_tool(
+        Tool(
+            name="Count",
+            pipeline_or_node=counter,
+            description="useful for when you need to count how many characters are in a word. Ask only with a single word.",
+        )
+    )
+
+    # TODO Replace Count tool once more tools are implemented so that we do not need to account for off-by-one errors
+    result = agent.run("How many characters are in the word Madrid?")
+    assert any(digit in result["answers"][0].answer for digit in ["5", "6", "five", "six"])
+
+    result = agent.run("How many letters does the name of the town where Christelle lives have?")
+    assert any(digit in result["answers"][0].answer for digit in ["5", "6", "five", "six"])
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize("reader", ["farm"], indirect=True)
+@pytest.mark.parametrize("retriever_with_docs, document_store_with_docs", [("bm25", "memory")], indirect=True)
+@pytest.mark.skipif(
+    not os.environ.get("OPENAI_API_KEY", None),
+    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
+)
+def test_agent_run_batch(reader, retriever_with_docs, document_store_with_docs):
+    search = ExtractiveQAPipeline(reader, retriever_with_docs)
+    prompt_model = PromptModel(model_name_or_path="text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"))
+    prompt_node = PromptNode(model_name_or_path=prompt_model, stop_words=["Observation:"])
+    counter = PromptNode(
+        model_name_or_path=prompt_model,
+        default_prompt_template=PromptTemplate(
+            name="calculator_template",
+            prompt_text="When I give you a word, respond with the number of characters that this word contains.\n"
+            "Word: Rome\nLength: 4\n"
+            "Word: Arles\nLength: 5\n"
+            "Word: Berlin\nLength: 6\n"
+            "Word: $query?\nLength: ",
+            prompt_params=["query"],
+        ),
+    )
+
+    agent = Agent(prompt_node=prompt_node)
+    agent.add_tool(
+        Tool(
+            name="Search",
+            pipeline_or_node=search,
+            description="useful for when you need to answer "
+            "questions about where people live. You "
+            "should ask targeted questions",
+        )
+    )
+    agent.add_tool(
+        Tool(
+            name="Count",
+            pipeline_or_node=counter,
+            description="useful for when you need to count how many characters are in a word. Ask only with a single word.",
+        )
+    )
+
+    results = agent.run_batch(
+        queries=[
+            "How many characters are in the word Madrid?",
+            "How many letters does the name of the town where Christelle lives have?",
+        ]
+    )
+    # TODO Replace Count tool once more tools are implemented so that we do not need to account for off-by-one errors
+    assert any(digit in results["answers"][0][0].answer for digit in ["5", "6", "five", "six"])
+    assert any(digit in results["answers"][1][0].answer for digit in ["5", "6", "five", "six"])
diff --git a/test/conftest.py b/test/conftest.py
index 0576d3c34..c566caa13 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -55,10 +55,11 @@ from haystack.nodes import (
     TransformersSummarizer,
     TransformersTranslator,
     QuestionGenerator,
+    PromptTemplate,
 )
 from haystack.modeling.infer import Inferencer, QAInferencer
 from haystack.nodes.prompt import PromptNode, PromptModel
-from haystack.schema import Document
+from haystack.schema import Document, FilterType
 from haystack.utils.import_utils import _optional_component_not_installed
 
 try:
@@ -272,11 +273,30 @@ class MockDocumentStore(BaseDocumentStore):
 class MockRetriever(BaseRetriever):
     outgoing_edges = 1
 
-    def retrieve(self, query: str, top_k: int):
-        pass
+    def retrieve(
+        self,
+        query: str,
+        filters: Optional[FilterType] = None,
+        top_k: Optional[int] = None,
+        index: Optional[str] = None,
+        headers: Optional[Dict[str, str]] = None,
+        scale_score: Optional[bool] = None,
+        document_store: Optional[BaseDocumentStore] = None,
+    ) -> List[Document]:
+        return []
 
-    def retrieve_batch(self, queries: List[str], top_k: int):
-        pass
+    def retrieve_batch(
+        self,
+        queries: List[str],
+        filters: Optional[Union[FilterType, List[Optional[FilterType]]]] = None,
+        top_k: Optional[int] = None,
+        index: Optional[str] = None,
+        headers: Optional[Dict[str, str]] = None,
+        batch_size: Optional[int] = None,
+        scale_score: Optional[bool] = None,
+        document_store: Optional[BaseDocumentStore] = None,
+    ) -> List[List[Document]]:
+        return [[]]
 
 
 class MockSeq2SegGenerator(BaseGenerator):
@@ -343,6 +363,40 @@ class MockReader(BaseReader):
         pass
 
 
+class MockPromptNode(PromptNode):
+    def __init__(self):
+        self.default_prompt_template = None
+
+    def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, **kwargs) -> List[str]:
+        return [""]
+
+    def get_prompt_template(self, prompt_template_name: str) -> PromptTemplate:
+        if prompt_template_name == "think-step-by-step":
+            return PromptTemplate(
+                name="think-step-by-step",
+                prompt_text="You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
+                "correctly, you have access to the following tools:\n\n"
+                "$tool_names_with_descriptions\n\n"
+                "To answer questions, you'll need to go through multiple steps involving step-by-step thinking and "
+                "selecting appropriate tools and their inputs; tools will respond with observations. When you are ready "
+                "for a final answer, respond with the `Final Answer:`\n\n"
+                "Use the following format:\n\n"
+                "Question: the question to be answered\n"
+                "Thought: Reason if you have the final answer. If yes, answer the question. If not, find out the missing information needed to answer it.\n"
+                "Tool: [$tool_names]\n"
+                "Tool Input: the input for the tool\n"
+                "Observation: the tool will respond with the result\n"
+                "...\n"
+                "Final Answer: the final answer to the question, make it short (1-5 words)\n\n"
+                "Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n"
+                "---\n\n"
+                "Question: $query\n"
+                "Thought: Let's think step-by-step, I first need to $generated_text",
+            )
+        else:
+            return PromptTemplate(name="", prompt_text="")
+
+
 #
 # Document collections
 #
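
Taken together, the new module, error class, and prompt template compose as in the following minimal end-to-end
sketch (an editor's example mirroring the integration tests above; `document_store` and the OpenAI model name are
assumptions, and an OPENAI_API_KEY environment variable must be set):

    import os

    from haystack.agents import Agent, Tool
    from haystack.nodes import BM25Retriever, FARMReader, PromptModel, PromptNode
    from haystack.pipelines import ExtractiveQAPipeline

    # A pipeline that the Agent can call as a tool
    retriever = BM25Retriever(document_store=document_store)
    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2")
    search = ExtractiveQAPipeline(reader, retriever)

    # The PromptNode drives the ReAct loop; stopping at "Observation:" lets the
    # Agent (not the LLM) run the tool and fill in the observation.
    prompt_model = PromptModel(model_name_or_path="text-davinci-003", api_key=os.environ["OPENAI_API_KEY"])
    prompt_node = PromptNode(model_name_or_path=prompt_model, stop_words=["Observation:"])

    agent = Agent(prompt_node=prompt_node)
    agent.add_tool(
        Tool(
            name="Search",
            pipeline_or_node=search,
            description="useful for when you need to answer questions about where people live",
        )
    )
    result = agent.run("Where does Christelle live?")
    print(result["answers"][0].answer)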