feat: Add Agent (#4148)

* initial Agent implementation

* mypy and pylint fixes

* add missing ABC import

* improved prompt template

* refactor and shorten run method

* refactor and shorten run method

* add tests for extracting

* fix mixed up tool_input/observation & make tests more robust

* fix bug with max_iterations and update prompt template

* allow setting prompt_template in Agent init

* remove example yml for agent

* add final prediction to transcript

* add transcript to errors and accept PromptTemplate in init

* simplify if else to elif

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* add checks for max_iter<2 and empty list returned by prompt node

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
Julian Risch 2023-02-21 14:27:40 +01:00 committed by GitHub
parent bde01cbf1f
commit 5ce7a404ac
9 changed files with 639 additions and 5 deletions

.github/labeler.yml (3 changes)

@@ -31,6 +31,9 @@ topic:reader:
topic:retriever:
- haystack/nodes/retriever/*
- test/nodes/test_retriever.py
topic:agent:
- haystack/agents/*
- test/agents/*
topic:pipeline:
- haystack/pipelines/*
- haystack/nodes/other/*


@@ -411,6 +411,7 @@ jobs:
folder:
- "nodes"
- "pipelines"
- "agents"
- "modeling"
- "others"
@@ -455,6 +456,7 @@ jobs:
folder:
- "nodes"
- "pipelines"
- "agents"
- "modeling"
#- "others"
@@ -504,6 +506,7 @@ jobs:
folder:
- "nodes"
- "pipelines"
- "agents"
- "modeling"
- "others"
@@ -599,6 +602,7 @@ jobs:
folder:
- "nodes"
- "pipelines"
- "agents"
- "modeling"
- "others"


@@ -0,0 +1,2 @@
from haystack.agents.base import Agent
from haystack.agents.base import Tool

haystack/agents/base.py (new file, 290 lines)

@@ -0,0 +1,290 @@
from __future__ import annotations
import logging
import re
from typing import List, Optional, Union, Dict, Tuple, Any
from haystack import Pipeline, BaseComponent, Answer
from haystack.errors import AgentError
from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
from haystack.pipelines import (
BaseStandardPipeline,
ExtractiveQAPipeline,
DocumentSearchPipeline,
GenerativeQAPipeline,
SearchSummarizationPipeline,
FAQPipeline,
TranslationWrapperPipeline,
RetrieverQuestionGenerationPipeline,
)
from haystack.telemetry import send_custom_event
logger = logging.getLogger(__name__)
class Tool:
"""
A tool is a pipeline or node that also has a name and description. When you add a Tool to an Agent, the Agent can
invoke the underlying pipeline or node to answer questions. The wording of the description is important for the
Agent to decide which tool is most useful for a given question.
:param name: The name of the tool. The Agent uses this name to refer to the tool in the text the Agent generates.
The name should be short, ideally one token, and a good description of what the tool can do, for example
"Calculator" or "Search".
:param pipeline_or_node: The pipeline or node to run when this tool is invoked by an Agent.
:param description: A description of what the tool is useful for. The Agent uses this description to decide
when to use which tool. For example, a tool for calculations can be described as "useful for when you need to
answer questions about math".
"""
def __init__(
self,
name: str,
pipeline_or_node: Union[
BaseComponent,
Pipeline,
ExtractiveQAPipeline,
DocumentSearchPipeline,
GenerativeQAPipeline,
SearchSummarizationPipeline,
FAQPipeline,
TranslationWrapperPipeline,
RetrieverQuestionGenerationPipeline,
],
description: str,
):
self.name = name
self.pipeline_or_node = pipeline_or_node
self.description = description
class Agent:
"""
An Agent answers queries by choosing between different tools, which are pipelines or nodes. It uses a large
language model (LLM) to generate a thought based on the query, choose a tool, and generate the input for the
tool. Based on the result returned by the tool the Agent can either stop if it now knows the answer or repeat the
process of 1) generate thought, 2) choose tool, 3) generate input.
Agents are useful for questions containing multiple subquestions that can be answered step-by-step (Multihop QA)
using multiple pipelines and nodes as tools.
"""
def __init__(
self,
prompt_node: PromptNode,
prompt_template: Union[str, PromptTemplate] = "zero-shot-react",
tools: Optional[List[Tool]] = None,
max_iterations: int = 5,
tool_pattern: str = r'Tool:\s*(\w+)\s*Tool Input:\s*("?)([^"\n]+)\2\s*',
final_answer_pattern: str = r"Final Answer:\s*(\w+)\s*",
):
"""
Creates an Agent instance.
:param prompt_node: The PromptNode that the Agent uses to decide which tool to use and what input to provide to it in each iteration.
:param prompt_template: The name of a PromptTemplate supported by the PromptNode or a new PromptTemplate. It is used for generating thoughts and running Tools to answer queries step-by-step.
:param tools: A List of Tools that the Agent can choose to run. If no Tools are provided, they need to be added with `add_tool()` before you can use the Agent.
:param max_iterations: The maximum number of iterations, that is, the number of times the Agent can run a tool plus one final iteration to infer that it knows the final answer.
Set it to at least 2 so that the Agent can run one Tool and then infer it knows the final answer. Defaults to 5.
:param tool_pattern: A regular expression to extract the name of the Tool and the corresponding input from the text generated by the Agent.
:param final_answer_pattern: A regular expression to extract final answer from the text generated by the Agent.
"""
self.prompt_node = prompt_node
self.prompt_template = (
prompt_node.get_prompt_template(prompt_template) if isinstance(prompt_template, str) else prompt_template
)
self.tools = {tool.name: tool for tool in tools} if tools else {}
self.tool_names = ", ".join(self.tools.keys())
self.tool_names_with_descriptions = "\n".join(
[f"{tool.name}: {tool.description}" for tool in self.tools.values()]
)
self.max_iterations = max_iterations
self.tool_pattern = tool_pattern
self.final_answer_pattern = final_answer_pattern
send_custom_event(event=f"{type(self).__name__} initialized")
def add_tool(self, tool: Tool):
"""
Add the provided tool to the Agent and update the template for the Agent's PromptNode.
:param tool: The Tool to add to the Agent. Any previously added tool with the same name will be overwritten.
"""
self.tools[tool.name] = tool
self.tool_names = ", ".join(self.tools.keys())
self.tool_names_with_descriptions = "\n".join(
[f"{tool.name}: {tool.description}" for tool in self.tools.values()]
)
def has_tool(self, tool_name: str):
"""
Check whether the Agent has a Tool registered under the provided tool name.
:param tool_name: The name of the Tool for which to check whether the Agent has it.
"""
return tool_name in self.tools
def run(
self, query: str, max_iterations: Optional[int] = None, params: Optional[dict] = None
) -> Dict[str, Union[str, List[Answer]]]:
"""
Runs the Agent given a query and optional parameters to pass on to the tools used. The result is in the
same format as a pipeline's result: a dictionary with a key `answers` containing a List of Answers, plus the
keys `query` and `transcript`.
:param query: The search query.
:param max_iterations: The number of times the Agent can run a tool plus one final iteration to infer that it knows the final answer.
If set, it must be at least 2 so that the Agent can run one tool and then infer it knows the final answer.
:param params: A dictionary of parameters that you want to pass to those tools that are pipelines.
To pass a parameter to all nodes in those pipelines, use: `{"top_k": 10}`.
To pass a parameter to targeted nodes in those pipelines, use:
`{"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}}`.
Parameters can only be passed to tools that are pipelines, not to nodes.
"""
if not self.tools:
raise AgentError(
"Agents without tools cannot be run. Add at least one tool using `add_tool()` or set the parameter `tools` when initializing the Agent."
)
if max_iterations is None:
max_iterations = self.max_iterations
if max_iterations < 2:
raise AgentError(
f"max_iterations was set to {max_iterations} but it should be at least 2 so that the Agent can run one tool and then infer it knows the final answer."
)
transcript = self._get_initial_transcript(query=query)
# Generate a thought with a plan for what to do, choose a tool, generate the input for it, and run it.
# Repeat this until the final answer is found or the maximum number of iterations is reached.
for _ in range(max_iterations):
preds = self.prompt_node(transcript)
if not preds:
raise AgentError(f"No output was generated by the Agent. Transcript:\n{transcript}")
# Try to extract final answer or tool name and input from the generated text and update the transcript
final_answer = self._extract_final_answer(pred=preds[0])
if final_answer is not None:
transcript += preds[0]
return self._format_answer(query=query, transcript=transcript, answer=final_answer)
tool_name, tool_input = self._extract_tool_name_and_tool_input(pred=preds[0])
if tool_name is None or tool_input is None:
raise AgentError(f"Wrong output format. Transcript:\n{transcript}")
result = self._run_tool(tool_name=tool_name, tool_input=tool_input, transcript=transcript, params=params)
observation = self._extract_observation(result)
transcript += f"{preds[0]}\nObservation: {observation}\nThought: Now that I know that {observation} is the answer to {tool_name} {tool_input}, I "
logger.warning(
"Maximum number of iterations (%s) reached for query (%s). Increase max_iterations "
"or no answer can be provided for this query.",
max_iterations,
query,
)
return self._format_answer(query=query, transcript=transcript, answer="")
def run_batch(
self, queries: List[str], max_iterations: Optional[int] = None, params: Optional[dict] = None
) -> Dict[str, str]:
"""
Runs the Agent in batch mode.
:param queries: List of search queries.
:param max_iterations: The number of times the Agent can run a tool plus one final iteration to infer that it knows the final answer.
If set, it must be at least 2 so that the Agent can run one tool and then infer it knows the final answer.
:param params: A dictionary of parameters that you want to pass to those tools that are pipelines.
To pass a parameter to all nodes in those pipelines, use: `{"top_k": 10}`.
To pass a parameter to targeted nodes in those pipelines, use:
`{"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}}`.
Parameters can only be passed to tools that are pipelines, not to nodes.
"""
results: Dict = {"queries": [], "answers": [], "transcripts": []}
for query in queries:
result = self.run(query=query, max_iterations=max_iterations, params=params)
results["queries"].append(result["query"])
results["answers"].append(result["answers"])
results["transcripts"].append(result["transcript"])
return results
def _run_tool(
self, tool_name: str, tool_input: str, transcript: str, params: Optional[dict] = None
) -> Union[Tuple[Dict[str, Any], str], Dict[str, Any]]:
if not self.has_tool(tool_name):
raise AgentError(
f"Cannot use the tool {tool_name} because it is not in the list of added tools {self.tools.keys()}."
"Add the tool using `add_tool()` or include it in the parameter `tools` when initializing the Agent."
f"Transcript:\n{transcript}"
)
pipeline_or_node = self.tools[tool_name].pipeline_or_node
# We can only pass params to pipelines but not to nodes
if isinstance(pipeline_or_node, (Pipeline, BaseStandardPipeline)):
result = pipeline_or_node.run(query=tool_input, params=params)
elif isinstance(pipeline_or_node, BaseRetriever):
result = pipeline_or_node.run(query=tool_input, root_node="Query")
else:
result = pipeline_or_node.run(query=tool_input)
return result
def _extract_observation(self, result: Union[Tuple[Dict[str, Any], str], Dict[str, Any]]) -> str:
observation = ""
# If the result was returned by a node, it is a tuple; we use only the output, not the name of the output edge.
# If the result was returned by a pipeline, it is a dict that we can use directly.
if isinstance(result, tuple):
result = result[0]
if isinstance(result, dict):
if result.get("results", None):
observation = result["results"][0]
elif result.get("answers", None):
observation = result["answers"][0].answer
elif result.get("documents", None):
observation = result["documents"][0].content
# observation remains "" if no result/answer/document was returned
return observation
def _extract_tool_name_and_tool_input(self, pred: str) -> Tuple[Optional[str], Optional[str]]:
"""
Parse the tool name and the tool input from the prediction output of the Agent's PromptNode.
:param pred: Prediction output of the Agent's PromptNode from which to parse the tool and tool input
"""
tool_match = re.search(self.tool_pattern, pred)
if tool_match:
tool_name = tool_match.group(1)
tool_input = tool_match.group(3)
return tool_name.strip('" []').strip(), tool_input.strip('" ')
return None, None
def _extract_final_answer(self, pred: str) -> Optional[str]:
"""
Parse the final answer from the prediction output of the Agent's PromptNode.
:param pred: Prediction output of the Agent's PromptNode from which to parse the final answer
"""
final_answer_match = re.search(self.final_answer_pattern, pred)
if final_answer_match:
final_answer = final_answer_match.group(1)
return final_answer.strip('" ')
return None
def _format_answer(self, query: str, answer: str, transcript: str) -> Dict[str, Union[str, List[Answer]]]:
"""
Formats an answer as a dict containing `query` and `answers` similar to the output of a Pipeline.
The dict also contains the full transcript, that is, the Agent's filled initial prompt template plus the text it generated during execution.
:param query: The search query.
:param answer: The final answer returned by the Agent. An empty string corresponds to no answer.
:param transcript: The text generated by the Agent and the initial filled template for debug purposes.
"""
return {"query": query, "answers": [Answer(answer=answer, type="generative")], "transcript": transcript}
def _get_initial_transcript(self, query: str):
"""
Fills the Agent's PromptTemplate with the query, the tool names, and the tool descriptions.
:param query: The search query.
"""
return next(
self.prompt_template.fill(
query=query, tool_names=self.tool_names, tool_names_with_descriptions=self.tool_names_with_descriptions
),
"",
)
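
For orientation, here is a minimal usage sketch of the Agent and Tool API introduced in this file. The document store, its contents, the BM25 retriever, and the OpenAI model are illustrative assumptions; the Agent, Tool, PromptModel, PromptNode, and DocumentSearchPipeline calls mirror the code and tests in this commit.

import os

from haystack.agents import Agent
from haystack.agents.base import Tool
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import BM25Retriever, PromptModel, PromptNode
from haystack.pipelines import DocumentSearchPipeline

# Illustrative retrieval tool; any pipeline or node can be wrapped in a Tool.
document_store = InMemoryDocumentStore(use_bm25=True)
document_store.write_documents([{"content": "Christelle lives in Paris."}])
retriever = BM25Retriever(document_store=document_store)
search = DocumentSearchPipeline(retriever)

# The stop word keeps the LLM from generating its own observations, as in the integration tests below.
prompt_model = PromptModel(model_name_or_path="text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"))
prompt_node = PromptNode(model_name_or_path=prompt_model, stop_words=["Observation:"])

agent = Agent(prompt_node=prompt_node, max_iterations=5)
agent.add_tool(
    Tool(
        name="Search",
        pipeline_or_node=search,
        description="useful for when you need to answer questions about where people live",
    )
)

# params are forwarded only to tools that are pipelines, as documented in Agent.run().
result = agent.run("Where does Christelle live?", params={"Retriever": {"top_k": 3}})
print(result["answers"][0].answer)
print(result["transcript"])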


@@ -47,6 +47,15 @@ class ModelingError(HaystackError):
super().__init__(message=message, docs_link=docs_link)
class AgentError(HaystackError):
"""Exception for issues raised within an agent"""
def __init__(
self, message: Optional[str] = None, docs_link: Optional[str] = "https://docs.haystack.deepset.ai/docs/agents"
):
super().__init__(message=message, docs_link=docs_link)
class PipelineError(HaystackError):
"""Exception for issues raised within a pipeline"""


@@ -689,6 +689,27 @@ def get_predefined_prompt_templates() -> List[PromptTemplate]:
name="translation",
prompt_text="Translate the following context to $target_language. Context: $documents; Translation:",
),
PromptTemplate(
name="zero-shot-react",
prompt_text="You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
"correctly, you have access to the following tools:\n\n"
"$tool_names_with_descriptions\n\n"
"To answer questions, you'll need to go through multiple steps involving step-by-step thinking and "
"selecting appropriate tools and their inputs; tools will respond with observations. When you are ready "
"for a final answer, respond with the `Final Answer:`\n\n"
"Use the following format:\n\n"
"Question: the question to be answered\n"
"Thought: Reason if you have the final answer. If yes, answer the question. If not, find out the missing information needed to answer it.\n"
"Tool: [$tool_names]\n"
"Tool Input: the input for the tool\n"
"Observation: the tool will respond with the result\n"
"...\n"
"Final Answer: the final answer to the question, make it short (1-5 words)\n\n"
"Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n"
"---\n\n"
"Question: $query\n"
"Thought: Let's think step-by-step, I first need to ",
),
]
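
To make the expected generation format concrete, here is a small sketch showing how the Agent's default tool_pattern and final_answer_pattern (see haystack/agents/base.py above) pick the tool call and the final answer out of text that follows this zero-shot-react template. The sample generations are made up for the example.

import re

# Default patterns copied from Agent.__init__ above.
tool_pattern = r'Tool:\s*(\w+)\s*Tool Input:\s*("?)([^"\n]+)\2\s*'
final_answer_pattern = r"Final Answer:\s*(\w+)\s*"

# Hypothetical generation for an intermediate step: the Agent picks a tool and its input.
step = "find out where Christelle lives.\nTool: Search\nTool Input: Where does Christelle live"
match = re.search(tool_pattern, step)
print(match.group(1))  # Search
print(match.group(3))  # Where does Christelle live

# Hypothetical generation for the last step: a short answer, of which the
# \w+ group in final_answer_pattern captures a single word.
last = "I now know the final answer.\nFinal Answer: Paris"
print(re.search(final_answer_pattern, last).group(1))  # Paris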

test/agents/__init__.py (new empty file)

test/agents/test_agent.py (new file, 251 lines)

@@ -0,0 +1,251 @@
import logging
import os
from typing import Tuple
import pytest
from haystack import BaseComponent, Answer
from haystack.agents import Agent
from haystack.agents.base import Tool
from haystack.errors import AgentError
from haystack.nodes import PromptModel, PromptNode, PromptTemplate
from haystack.pipelines import ExtractiveQAPipeline, DocumentSearchPipeline, BaseStandardPipeline
from test.conftest import MockRetriever, MockPromptNode
def test_add_and_overwrite_tool():
# Add a Node as a Tool to an Agent
agent = Agent(prompt_node=MockPromptNode())
retriever = MockRetriever()
agent.add_tool(
Tool(
name="Retriever",
pipeline_or_node=retriever,
description="useful for when you need to " "retrieve documents from your index",
)
)
assert len(agent.tools) == 1
assert "Retriever" in agent.tools
assert agent.has_tool(tool_name="Retriever")
assert isinstance(agent.tools["Retriever"].pipeline_or_node, BaseComponent)
agent.add_tool(
Tool(
name="Retriever",
pipeline_or_node=retriever,
description="useful for when you need to retrieve documents from your index",
)
)
# Add a Pipeline as a Tool to an Agent and overwrite the previously added Tool
retriever_pipeline = DocumentSearchPipeline(MockRetriever())
agent.add_tool(
Tool(
name="Retriever",
pipeline_or_node=retriever_pipeline,
description="useful for when you need to retrieve documents from your index",
)
)
assert len(agent.tools) == 1
assert "Retriever" in agent.tools
assert agent.has_tool(tool_name="Retriever")
assert isinstance(agent.tools["Retriever"].pipeline_or_node, BaseStandardPipeline)
def test_agent_chooses_no_action():
agent = Agent(prompt_node=MockPromptNode())
retriever = MockRetriever()
agent.add_tool(
Tool(
name="Retriever",
pipeline_or_node=retriever,
description="useful for when you need to retrieve documents from your index",
)
)
with pytest.raises(AgentError, match=r"Wrong output format.*"):
agent.run("How many letters does the name of the town where Christelle lives have?")
def test_max_iterations(caplog, monkeypatch):
# Run an Agent and stop because max_iterations is reached
agent = Agent(prompt_node=MockPromptNode(), max_iterations=3)
retriever = MockRetriever()
agent.add_tool(
Tool(
name="Retriever",
pipeline_or_node=retriever,
description="useful for when you need to retrieve documents from your index",
)
)
# Let the Agent always choose "Retriever" as the Tool with "" as input
def mock_extract_tool_name_and_tool_input(self, pred: str) -> Tuple[str, str]:
return "Retriever", ""
monkeypatch.setattr(Agent, "_extract_tool_name_and_tool_input", mock_extract_tool_name_and_tool_input)
# Using max_iterations as specified in the Agent's init method
with caplog.at_level(logging.WARN, logger="haystack.agents"):
result = agent.run("Where does Christelle live?")
assert result["answers"] == [Answer(answer="", type="generative")]
assert "Maximum number of iterations (3) reached" in caplog.text
# Setting max_iterations in the Agent's run method
with caplog.at_level(logging.WARN, logger="haystack.agents"):
result = agent.run("Where does Christelle live?", max_iterations=2)
assert result["answers"] == [Answer(answer="", type="generative")]
assert "Maximum number of iterations (2) reached" in caplog.text
def test_run_tool():
agent = Agent(prompt_node=MockPromptNode())
retriever = MockRetriever()
agent.add_tool(
Tool(
name="Retriever",
pipeline_or_node=retriever,
description="useful for when you need to retrieve documents from your index",
)
)
result = agent._run_tool(tool_name="Retriever", tool_input="", transcript="")
assert result[0]["documents"] == []
def test_extract_observation():
agent = Agent(prompt_node=MockPromptNode())
observation = agent._extract_observation(
result={
"answers": [
Answer(answer="first answer", type="generative"),
Answer(answer="second answer", type="generative"),
]
}
)
assert observation == "first answer"
def test_extract_tool_name_and_tool_input():
agent = Agent(prompt_node=MockPromptNode())
pred = "need to find out what city he was born.\nTool: Search\nTool Input: Where was Jeremy McKinnon born"
tool_name, tool_input = agent._extract_tool_name_and_tool_input(pred)
assert tool_name == "Search" and tool_input == "Where was Jeremy McKinnon born"
def test_extract_final_answer():
agent = Agent(prompt_node=MockPromptNode())
pred = "have the final answer to the question.\nFinal Answer: Florida"
final_answer = agent._extract_final_answer(pred)
assert final_answer == "Florida"
def test_format_answer():
agent = Agent(prompt_node=MockPromptNode())
formatted_answer = agent._format_answer(query="query", answer="answer", transcript="transcript")
assert formatted_answer["query"] == "query"
assert formatted_answer["answers"] == [Answer(answer="answer", type="generative")]
assert formatted_answer["transcript"] == "transcript"
@pytest.mark.integration
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
@pytest.mark.parametrize("retriever_with_docs, document_store_with_docs", [("bm25", "memory")], indirect=True)
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_agent_run(reader, retriever_with_docs, document_store_with_docs):
search = ExtractiveQAPipeline(reader, retriever_with_docs)
prompt_model = PromptModel(model_name_or_path="text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"))
prompt_node = PromptNode(model_name_or_path=prompt_model, stop_words=["Observation:"])
counter = PromptNode(
model_name_or_path=prompt_model,
default_prompt_template=PromptTemplate(
name="calculator_template",
prompt_text="When I give you a word, respond with the number of characters that this word contains.\n"
"Word: Rome\nLength: 4\n"
"Word: Arles\nLength: 5\n"
"Word: Berlin\nLength: 6\n"
"Word: $query?\nLength: ",
prompt_params=["query"],
),
)
agent = Agent(prompt_node=prompt_node)
agent.add_tool(
Tool(
name="Search",
pipeline_or_node=search,
description="useful for when you need to answer "
"questions about where people live. You "
"should ask targeted questions",
)
)
agent.add_tool(
Tool(
name="Count",
pipeline_or_node=counter,
description="useful for when you need to count how many characters are in a word. Ask only with a single word.",
)
)
# TODO Replace Count tool once more tools are implemented so that we do not need to account for off-by-one errors
result = agent.run("How many characters are in the word Madrid?")
assert any(digit in result["answers"][0].answer for digit in ["5", "6", "five", "six"])
result = agent.run("How many letters does the name of the town where Christelle lives have?")
assert any(digit in result["answers"][0].answer for digit in ["5", "6", "five", "six"])
@pytest.mark.integration
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
@pytest.mark.parametrize("retriever_with_docs, document_store_with_docs", [("bm25", "memory")], indirect=True)
@pytest.mark.skipif(
not os.environ.get("OPENAI_API_KEY", None),
reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_agent_run_batch(reader, retriever_with_docs, document_store_with_docs):
search = ExtractiveQAPipeline(reader, retriever_with_docs)
prompt_model = PromptModel(model_name_or_path="text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"))
prompt_node = PromptNode(model_name_or_path=prompt_model, stop_words=["Observation:"])
counter = PromptNode(
model_name_or_path=prompt_model,
default_prompt_template=PromptTemplate(
name="calculator_template",
prompt_text="When I give you a word, respond with the number of characters that this word contains.\n"
"Word: Rome\nLength: 4\n"
"Word: Arles\nLength: 5\n"
"Word: Berlin\nLength: 6\n"
"Word: $query?\nLength: ",
prompt_params=["query"],
),
)
agent = Agent(prompt_node=prompt_node)
agent.add_tool(
Tool(
name="Search",
pipeline_or_node=search,
description="useful for when you need to answer "
"questions about where people live. You "
"should ask targeted questions",
)
)
agent.add_tool(
Tool(
name="Count",
pipeline_or_node=counter,
description="useful for when you need to count how many characters are in a word. Ask only with a single word.",
)
)
results = agent.run_batch(
queries=[
"How many characters are in the word Madrid?",
"How many letters does the name of the town where Christelle lives have?",
]
)
# TODO Replace Count tool once more tools are implemented so that we do not need to account for off-by-one errors
assert any(digit in results["answers"][0][0].answer for digit in ["5", "6", "five", "six"])
assert any(digit in results["answers"][1][0].answer for digit in ["5", "6", "five", "six"])


@@ -55,10 +55,11 @@ from haystack.nodes import (
TransformersSummarizer,
TransformersTranslator,
QuestionGenerator,
PromptTemplate,
)
from haystack.modeling.infer import Inferencer, QAInferencer
from haystack.nodes.prompt import PromptNode, PromptModel
from haystack.schema import Document
from haystack.schema import Document, FilterType
from haystack.utils.import_utils import _optional_component_not_installed
try:
@@ -272,11 +273,30 @@ class MockDocumentStore(BaseDocumentStore):
class MockRetriever(BaseRetriever):
outgoing_edges = 1
def retrieve(self, query: str, top_k: int):
pass
def retrieve(
self,
query: str,
filters: Optional[FilterType] = None,
top_k: Optional[int] = None,
index: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
scale_score: Optional[bool] = None,
document_store: Optional[BaseDocumentStore] = None,
) -> List[Document]:
return []
def retrieve_batch(self, queries: List[str], top_k: int):
pass
def retrieve_batch(
self,
queries: List[str],
filters: Optional[Union[FilterType, List[Optional[FilterType]]]] = None,
top_k: Optional[int] = None,
index: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
batch_size: Optional[int] = None,
scale_score: Optional[bool] = None,
document_store: Optional[BaseDocumentStore] = None,
) -> List[List[Document]]:
return [[]]
class MockSeq2SegGenerator(BaseGenerator):
@@ -343,6 +363,40 @@ class MockReader(BaseReader):
pass
class MockPromptNode(PromptNode):
def __init__(self):
self.default_prompt_template = None
def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, **kwargs) -> List[str]:
return [""]
def get_prompt_template(self, prompt_template_name: str) -> PromptTemplate:
if prompt_template_name == "think-step-by-step":
return PromptTemplate(
name="think-step-by-step",
prompt_text="You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
"correctly, you have access to the following tools:\n\n"
"$tool_names_with_descriptions\n\n"
"To answer questions, you'll need to go through multiple steps involving step-by-step thinking and "
"selecting appropriate tools and their inputs; tools will respond with observations. When you are ready "
"for a final answer, respond with the `Final Answer:`\n\n"
"Use the following format:\n\n"
"Question: the question to be answered\n"
"Thought: Reason if you have the final answer. If yes, answer the question. If not, find out the missing information needed to answer it.\n"
"Tool: [$tool_names]\n"
"Tool Input: the input for the tool\n"
"Observation: the tool will respond with the result\n"
"...\n"
"Final Answer: the final answer to the question, make it short (1-5 words)\n\n"
"Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n"
"---\n\n"
"Question: $query\n"
"Thought: Let's think step-by-step, I first need to $generated_text",
)
else:
return PromptTemplate(name="", prompt_text="")
#
# Document collections
#