Mirror of https://github.com/deepset-ai/haystack.git, synced 2025-12-09 13:56:58 +00:00
feat: Add Agent (#4148)
* initial Agent implementation
* mypy and pylint fixes
* add missing ABC import
* improved prompt template
* refactor and shorten run method
* refactor and shorten run method
* add tests for extracting
* fix mixed up tool_input/observation & make tests more robust
* fix bug with max_iterations and update prompt template
* allow setting prompt_template in Agent init
* remove example yml for agent
* add final prediction to transcript
* add transcript to errors and accept PromptTemplate in init
* simplify if else to elif

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>

* add checks for max_iter<2 and empty list returned by prompt node

---------

Co-authored-by: Silvano Cerza <3314350+silvanocerza@users.noreply.github.com>
This commit is contained in:
parent bde01cbf1f
commit 5ce7a404ac
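In practice, the pieces added by this commit fit together like this. A minimal sketch, assuming an existing pipeline named `search_pipeline` (for example, an ExtractiveQAPipeline over your document store) and a valid OpenAI API key; both are assumptions, not part of the diff:

from haystack.agents import Agent, Tool
from haystack.nodes import PromptNode

# `search_pipeline` and the API key placeholder are assumptions for this sketch.
prompt_node = PromptNode(
    model_name_or_path="text-davinci-003", api_key="YOUR_OPENAI_API_KEY", stop_words=["Observation:"]
)
agent = Agent(prompt_node=prompt_node)  # uses the new "zero-shot-react" template by default
agent.add_tool(
    Tool(
        name="Search",
        pipeline_or_node=search_pipeline,
        description="useful for when you need to answer questions about where people live",
    )
)
result = agent.run("Where does Christelle live?")
print(result["answers"][0].answer)
print(result["transcript"])  # the full prompt plus everything the model generated, for debugging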
3 .github/labeler.yml vendored
@@ -31,6 +31,9 @@ topic:reader:
 topic:retriever:
   - haystack/nodes/retriever/*
   - test/nodes/test_retriever.py
+topic:agent:
+  - haystack/agents/*
+  - test/agents/*
 topic:pipeline:
   - haystack/pipelines/*
   - haystack/nodes/other/*
4 .github/workflows/tests.yml vendored
@@ -411,6 +411,7 @@ jobs:
         folder:
           - "nodes"
           - "pipelines"
+          - "agents"
           - "modeling"
           - "others"

@@ -455,6 +456,7 @@ jobs:
         folder:
           - "nodes"
           - "pipelines"
+          - "agents"
           - "modeling"
           #- "others"

@@ -504,6 +506,7 @@ jobs:
         folder:
           - "nodes"
           - "pipelines"
+          - "agents"
           - "modeling"
           - "others"

@@ -599,6 +602,7 @@ jobs:
         folder:
           - "nodes"
           - "pipelines"
+          - "agents"
           - "modeling"
           - "others"
2 haystack/agents/__init__.py Normal file
@@ -0,0 +1,2 @@
from haystack.agents.base import Agent
from haystack.agents.base import Tool
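With this module in place, both classes are importable from the package root, for example:

from haystack.agents import Agent, Tool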
290 haystack/agents/base.py Normal file
@@ -0,0 +1,290 @@
from __future__ import annotations

import logging
import re
from typing import List, Optional, Union, Dict, Tuple, Any

from haystack import Pipeline, BaseComponent, Answer
from haystack.errors import AgentError
from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
from haystack.pipelines import (
    BaseStandardPipeline,
    ExtractiveQAPipeline,
    DocumentSearchPipeline,
    GenerativeQAPipeline,
    SearchSummarizationPipeline,
    FAQPipeline,
    TranslationWrapperPipeline,
    RetrieverQuestionGenerationPipeline,
)
from haystack.telemetry import send_custom_event

logger = logging.getLogger(__name__)


class Tool:
    """
    A tool is a pipeline or node that also has a name and description. When you add a Tool to an Agent, the Agent can
    invoke the underlying pipeline or node to answer questions. The wording of the description is important for the
    Agent to decide which tool is most useful for a given question.

    :param name: The name of the tool. The Agent uses this name to refer to the tool in the text it generates.
        The name should be short, ideally one token, and a good description of what the tool can do, for example
        "Calculator" or "Search".
    :param pipeline_or_node: The pipeline or node to run when this tool is invoked by an Agent.
    :param description: A description of what the tool is useful for. An Agent can use this description when deciding
        which tool to use for a question. For example, a tool for calculations can be described by "useful for when
        you need to answer questions about math".
    """

    def __init__(
        self,
        name: str,
        pipeline_or_node: Union[
            BaseComponent,
            Pipeline,
            ExtractiveQAPipeline,
            DocumentSearchPipeline,
            GenerativeQAPipeline,
            SearchSummarizationPipeline,
            FAQPipeline,
            TranslationWrapperPipeline,
            RetrieverQuestionGenerationPipeline,
        ],
        description: str,
    ):
        self.name = name
        self.pipeline_or_node = pipeline_or_node
        self.description = description

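# An illustration (a sketch, not part of the committed file): wrapping a node as a
# Tool. `retriever` is assumed to be any retriever instance, as in the tests below.
#
#     retriever_tool = Tool(
#         name="Retriever",
#         pipeline_or_node=retriever,
#         description="useful for when you need to retrieve documents from your index",
#     )
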

class Agent:
    """
    An Agent answers queries by choosing between different tools, which are pipelines or nodes. It uses a large
    language model (LLM) to generate a thought based on the query, choose a tool, and generate the input for the
    tool. Based on the result returned by the tool, the Agent either stops because it now knows the answer, or
    repeats the process of 1) generating a thought, 2) choosing a tool, and 3) generating the tool's input.

    Agents are useful for questions containing multiple subquestions that can be answered step-by-step (multi-hop QA)
    using multiple pipelines and nodes as tools.
    """

    def __init__(
        self,
        prompt_node: PromptNode,
        prompt_template: Union[str, PromptTemplate] = "zero-shot-react",
        tools: Optional[List[Tool]] = None,
        max_iterations: int = 5,
        tool_pattern: str = r'Tool:\s*(\w+)\s*Tool Input:\s*("?)([^"\n]+)\2\s*',
        final_answer_pattern: str = r"Final Answer:\s*(\w+)\s*",
    ):
        """
        Creates an Agent instance.

        :param prompt_node: The PromptNode that the Agent uses to decide in each iteration which tool to use and what input to provide to it.
        :param prompt_template: The name of a PromptTemplate supported by the PromptNode or a new PromptTemplate. It is used for generating thoughts and running Tools to answer queries step-by-step.
        :param tools: A list of Tools that the Agent can choose to run. If no Tools are provided, you need to add them with `add_tool()` before you can use the Agent.
        :param max_iterations: The number of times the Agent can run a tool, plus one final iteration to infer that it knows the final answer.
            Set it to at least 2 so that the Agent can run one Tool and then infer it knows the final answer. Defaults to 5.
        :param tool_pattern: A regular expression to extract the name of the Tool and the corresponding input from the text generated by the Agent.
        :param final_answer_pattern: A regular expression to extract the final answer from the text generated by the Agent.
        """
        self.prompt_node = prompt_node
        self.prompt_template = (
            prompt_node.get_prompt_template(prompt_template) if isinstance(prompt_template, str) else prompt_template
        )
        self.tools = {tool.name: tool for tool in tools} if tools else {}
        self.tool_names = ", ".join(self.tools.keys())
        self.tool_names_with_descriptions = "\n".join(
            [f"{tool.name}: {tool.description}" for tool in self.tools.values()]
        )
        self.max_iterations = max_iterations
        self.tool_pattern = tool_pattern
        self.final_answer_pattern = final_answer_pattern
        send_custom_event(event=f"{type(self).__name__} initialized")

    def add_tool(self, tool: Tool):
        """
        Add the provided tool to the Agent and update the template for the Agent's PromptNode.

        :param tool: The Tool to add to the Agent. Any previously added Tool with the same name is overwritten.
        """
        self.tools[tool.name] = tool
        self.tool_names = ", ".join(self.tools.keys())
        self.tool_names_with_descriptions = "\n".join(
            [f"{tool.name}: {tool.description}" for tool in self.tools.values()]
        )

    def has_tool(self, tool_name: str):
        """
        Check whether the Agent has a Tool registered under the provided tool name.

        :param tool_name: The name of the Tool to check for.
        """
        return tool_name in self.tools

    def run(
        self, query: str, max_iterations: Optional[int] = None, params: Optional[dict] = None
    ) -> Dict[str, Union[str, List[Answer]]]:
        """
        Runs the Agent given a query and optional parameters to pass on to the tools used. The result is in the
        same format as a pipeline's result: a dictionary with a key `answers` containing a list of Answers.

        :param query: The search query.
        :param max_iterations: The number of times the Agent can run a tool, plus one final iteration to infer that it knows the final answer.
            If set, it should be at least 2 so that the Agent can run one tool and then infer it knows the final answer.
        :param params: A dictionary of parameters that you want to pass to those tools that are pipelines.
            To pass a parameter to all nodes in those pipelines, use: `{"top_k": 10}`.
            To pass a parameter to targeted nodes in those pipelines, use:
            `{"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}}`.
            Parameters can only be passed to tools that are pipelines, not to nodes.
        """
        if not self.tools:
            raise AgentError(
                "Agents without tools cannot be run. Add at least one tool using `add_tool()` or set the parameter `tools` when initializing the Agent."
            )
        if max_iterations is None:
            max_iterations = self.max_iterations
        if max_iterations < 2:
            raise AgentError(
                f"max_iterations was set to {max_iterations} but it should be at least 2 so that the Agent can run one tool and then infer it knows the final answer."
            )

        transcript = self._get_initial_transcript(query=query)
        # Generate a thought with a plan of what to do, choose a tool, generate input for it, and run it.
        # Repeat this until the final answer is found or the maximum number of iterations is reached.
        for _ in range(max_iterations):
            preds = self.prompt_node(transcript)
            if not preds:
                raise AgentError(f"No output was generated by the Agent. Transcript:\n{transcript}")

            # Try to extract the final answer, or the tool name and input, from the generated text and update the transcript
            final_answer = self._extract_final_answer(pred=preds[0])
            if final_answer is not None:
                transcript += preds[0]
                return self._format_answer(query=query, transcript=transcript, answer=final_answer)
            tool_name, tool_input = self._extract_tool_name_and_tool_input(pred=preds[0])
            if tool_name is None or tool_input is None:
                raise AgentError(f"Wrong output format. Transcript:\n{transcript}")

            result = self._run_tool(tool_name=tool_name, tool_input=tool_input, transcript=transcript, params=params)
            observation = self._extract_observation(result)
            transcript += f"{preds[0]}\nObservation: {observation}\nThought: Now that I know that {observation} is the answer to {tool_name} {tool_input}, I "

        logger.warning(
            "Maximum number of iterations (%s) reached for query (%s). Increase max_iterations, "
            "or no answer can be provided for this query.",
            max_iterations,
            query,
        )
        return self._format_answer(query=query, transcript=transcript, answer="")

    def run_batch(
        self, queries: List[str], max_iterations: Optional[int] = None, params: Optional[dict] = None
    ) -> Dict[str, str]:
        """
        Runs the Agent in batch mode, processing one query after the other.

        :param queries: List of search queries.
        :param max_iterations: The number of times the Agent can run a tool, plus one final iteration to infer that it knows the final answer.
            If set, it should be at least 2 so that the Agent can run one tool and then infer it knows the final answer.
        :param params: A dictionary of parameters that you want to pass to those tools that are pipelines.
            To pass a parameter to all nodes in those pipelines, use: `{"top_k": 10}`.
            To pass a parameter to targeted nodes in those pipelines, use:
            `{"Retriever": {"top_k": 10}, "Reader": {"top_k": 3}}`.
            Parameters can only be passed to tools that are pipelines, not to nodes.
        """
        results: Dict = {"queries": [], "answers": [], "transcripts": []}
        for query in queries:
            result = self.run(query=query, max_iterations=max_iterations, params=params)
            results["queries"].append(result["query"])
            results["answers"].append(result["answers"])
            results["transcripts"].append(result["transcript"])

        return results

    def _run_tool(
        self, tool_name: str, tool_input: str, transcript: str, params: Optional[dict] = None
    ) -> Union[Tuple[Dict[str, Any], str], Dict[str, Any]]:
        if not self.has_tool(tool_name):
            raise AgentError(
                f"Cannot use the tool {tool_name} because it is not in the list of added tools {self.tools.keys()}. "
                "Add the tool using `add_tool()` or include it in the parameter `tools` when initializing the Agent. "
                f"Transcript:\n{transcript}"
            )

        pipeline_or_node = self.tools[tool_name].pipeline_or_node
        # We can pass params to pipelines but not to nodes
        if isinstance(pipeline_or_node, (Pipeline, BaseStandardPipeline)):
            result = pipeline_or_node.run(query=tool_input, params=params)
        elif isinstance(pipeline_or_node, BaseRetriever):
            result = pipeline_or_node.run(query=tool_input, root_node="Query")
        else:
            result = pipeline_or_node.run(query=tool_input)
        return result

    def _extract_observation(self, result: Union[Tuple[Dict[str, Any], str], Dict[str, Any]]) -> str:
        observation = ""
        # If the result was returned by a node, it is a tuple. We use only the output, not the name of the output.
        # If the result was returned by a pipeline, it is a dict that we can use directly.
        if isinstance(result, tuple):
            result = result[0]
        if isinstance(result, dict):
            if result.get("results", None):
                observation = result["results"][0]
            elif result.get("answers", None):
                observation = result["answers"][0].answer
            elif result.get("documents", None):
                observation = result["documents"][0].content

        # observation remains "" if no result/answer/document was returned
        return observation

    def _extract_tool_name_and_tool_input(self, pred: str) -> Tuple[Optional[str], Optional[str]]:
        """
        Parse the tool name and the tool input from the prediction output of the Agent's PromptNode.

        :param pred: Prediction output of the Agent's PromptNode from which to parse the tool and tool input.
        """
        tool_match = re.search(self.tool_pattern, pred)
        if tool_match:
            tool_name = tool_match.group(1)
            tool_input = tool_match.group(3)
            return tool_name.strip('" []').strip(), tool_input.strip('" ')
        return None, None

    def _extract_final_answer(self, pred: str) -> Optional[str]:
        """
        Parse the final answer from the prediction output of the Agent's PromptNode.

        :param pred: Prediction output of the Agent's PromptNode from which to parse the final answer.
        """
        final_answer_match = re.search(self.final_answer_pattern, pred)
        if final_answer_match:
            final_answer = final_answer_match.group(1)
            return final_answer.strip('" ')
        return None

    def _format_answer(self, query: str, answer: str, transcript: str) -> Dict[str, Union[str, List[Answer]]]:
        """
        Formats an answer as a dict containing `query` and `answers`, similar to the output of a Pipeline.
        The dict also contains the full transcript, that is, the Agent's initial filled prompt template plus the
        text it generated during execution.

        :param query: The search query.
        :param answer: The final answer returned by the Agent. An empty string corresponds to no answer.
        :param transcript: The text generated by the Agent and the initial filled template, for debugging purposes.
        """
        return {"query": query, "answers": [Answer(answer=answer, type="generative")], "transcript": transcript}

    def _get_initial_transcript(self, query: str):
        """
        Fills the Agent's PromptTemplate with the query, tool names, and descriptions.

        :param query: The search query.
        """
        return next(
            self.prompt_template.fill(
                query=query, tool_names=self.tool_names, tool_names_with_descriptions=self.tool_names_with_descriptions
            ),
            "",
        )
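The two default extraction patterns are easiest to understand by example. Here is a quick sketch of what `_extract_tool_name_and_tool_input` and `_extract_final_answer` match, using predictions like the ones in the tests further down:

import re

tool_pattern = r'Tool:\s*(\w+)\s*Tool Input:\s*("?)([^"\n]+)\2\s*'
final_answer_pattern = r"Final Answer:\s*(\w+)\s*"

pred = "need to find out what city he was born.\nTool: Search\nTool Input: Where was Jeremy McKinnon born"
match = re.search(tool_pattern, pred)
print(match.group(1), "|", match.group(3))  # Search | Where was Jeremy McKinnon born

pred = "have the final answer to the question.\nFinal Answer: Florida"
print(re.search(final_answer_pattern, pred).group(1))  # Florida
# Note that `(\w+)` captures a single word token only, so a multi-word final answer
# would be truncated by the default pattern.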
haystack/errors.py

@@ -47,6 +47,15 @@ class ModelingError(HaystackError):
         super().__init__(message=message, docs_link=docs_link)


+class AgentError(HaystackError):
+    """Exception for issues raised within an agent"""
+
+    def __init__(
+        self, message: Optional[str] = None, docs_link: Optional[str] = "https://docs.haystack.deepset.ai/docs/agents"
+    ):
+        super().__init__(message=message, docs_link=docs_link)
+
+
 class PipelineError(HaystackError):
     """Exception for issues raised within a pipeline"""
haystack/nodes/prompt/prompt_node.py

@@ -689,6 +689,27 @@ def get_predefined_prompt_templates() -> List[PromptTemplate]:
             name="translation",
             prompt_text="Translate the following context to $target_language. Context: $documents; Translation:",
         ),
+        PromptTemplate(
+            name="zero-shot-react",
+            prompt_text="You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
+            "correctly, you have access to the following tools:\n\n"
+            "$tool_names_with_descriptions\n\n"
+            "To answer questions, you'll need to go through multiple steps involving step-by-step thinking and "
+            "selecting appropriate tools and their inputs; tools will respond with observations. When you are ready "
+            "for a final answer, respond with the `Final Answer:`\n\n"
+            "Use the following format:\n\n"
+            "Question: the question to be answered\n"
+            "Thought: Reason if you have the final answer. If yes, answer the question. If not, find out the missing information needed to answer it.\n"
+            "Tool: [$tool_names]\n"
+            "Tool Input: the input for the tool\n"
+            "Observation: the tool will respond with the result\n"
+            "...\n"
+            "Final Answer: the final answer to the question, make it short (1-5 words)\n\n"
+            "Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n"
+            "---\n\n"
+            "Question: $query\n"
+            "Thought: Let's think step-by-step, I first need to ",
+        ),
     ]
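PromptTemplate uses `$`-style placeholders, so the filled prompt can be previewed with Python's string.Template. This is a rough sketch of the placeholder mechanics only; the real code path is `PromptTemplate.fill()`, called from `Agent._get_initial_transcript()`:

from string import Template

# Only the tail of the zero-shot-react template is shown here, for brevity.
snippet = Template(
    "you have access to the following tools:\n\n$tool_names_with_descriptions\n\n"
    "Question: $query\nThought: Let's think step-by-step, I first need to "
)
print(
    snippet.substitute(
        tool_names_with_descriptions="Search: useful for when you need to answer questions about where people live",
        query="Where does Christelle live?",
    )
)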
0 test/agents/__init__.py Normal file

251 test/agents/test_agent.py Normal file
@@ -0,0 +1,251 @@
import logging
import os
from typing import Tuple

import pytest

from haystack import BaseComponent, Answer
from haystack.agents import Agent
from haystack.agents.base import Tool
from haystack.errors import AgentError
from haystack.nodes import PromptModel, PromptNode, PromptTemplate
from haystack.pipelines import ExtractiveQAPipeline, DocumentSearchPipeline, BaseStandardPipeline
from test.conftest import MockRetriever, MockPromptNode


def test_add_and_overwrite_tool():
    # Add a Node as a Tool to an Agent
    agent = Agent(prompt_node=MockPromptNode())
    retriever = MockRetriever()
    agent.add_tool(
        Tool(
            name="Retriever",
            pipeline_or_node=retriever,
            description="useful for when you need to retrieve documents from your index",
        )
    )
    assert len(agent.tools) == 1
    assert "Retriever" in agent.tools
    assert agent.has_tool(tool_name="Retriever")
    assert isinstance(agent.tools["Retriever"].pipeline_or_node, BaseComponent)

    agent.add_tool(
        Tool(
            name="Retriever",
            pipeline_or_node=retriever,
            description="useful for when you need to retrieve documents from your index",
        )
    )

    # Add a Pipeline as a Tool to an Agent and overwrite the previously added Tool
    retriever_pipeline = DocumentSearchPipeline(MockRetriever())
    agent.add_tool(
        Tool(
            name="Retriever",
            pipeline_or_node=retriever_pipeline,
            description="useful for when you need to retrieve documents from your index",
        )
    )
    assert len(agent.tools) == 1
    assert "Retriever" in agent.tools
    assert agent.has_tool(tool_name="Retriever")
    assert isinstance(agent.tools["Retriever"].pipeline_or_node, BaseStandardPipeline)


def test_agent_chooses_no_action():
    agent = Agent(prompt_node=MockPromptNode())
    retriever = MockRetriever()
    agent.add_tool(
        Tool(
            name="Retriever",
            pipeline_or_node=retriever,
            description="useful for when you need to retrieve documents from your index",
        )
    )
    with pytest.raises(AgentError, match=r"Wrong output format.*"):
        agent.run("How many letters does the name of the town where Christelle lives have?")


def test_max_iterations(caplog, monkeypatch):
    # Run an Agent and stop because max_iterations is reached
    agent = Agent(prompt_node=MockPromptNode(), max_iterations=3)
    retriever = MockRetriever()
    agent.add_tool(
        Tool(
            name="Retriever",
            pipeline_or_node=retriever,
            description="useful for when you need to retrieve documents from your index",
        )
    )

    # Let the Agent always choose "Retriever" as the Tool with "" as input
    def mock_extract_tool_name_and_tool_input(self, pred: str) -> Tuple[str, str]:
        return "Retriever", ""

    monkeypatch.setattr(Agent, "_extract_tool_name_and_tool_input", mock_extract_tool_name_and_tool_input)

    # Using max_iterations as specified in the Agent's init method
    with caplog.at_level(logging.WARN, logger="haystack.agents"):
        result = agent.run("Where does Christelle live?")
        assert result["answers"] == [Answer(answer="", type="generative")]
        assert "Maximum number of iterations (3) reached" in caplog.text

    # Setting max_iterations in the Agent's run method
    with caplog.at_level(logging.WARN, logger="haystack.agents"):
        result = agent.run("Where does Christelle live?", max_iterations=2)
        assert result["answers"] == [Answer(answer="", type="generative")]
        assert "Maximum number of iterations (2) reached" in caplog.text


def test_run_tool():
    agent = Agent(prompt_node=MockPromptNode())
    retriever = MockRetriever()
    agent.add_tool(
        Tool(
            name="Retriever",
            pipeline_or_node=retriever,
            description="useful for when you need to retrieve documents from your index",
        )
    )
    result = agent._run_tool(tool_name="Retriever", tool_input="", transcript="")
    assert result[0]["documents"] == []


def test_extract_observation():
    agent = Agent(prompt_node=MockPromptNode())
    observation = agent._extract_observation(
        result={
            "answers": [
                Answer(answer="first answer", type="generative"),
                Answer(answer="second answer", type="generative"),
            ]
        }
    )
    assert observation == "first answer"


def test_extract_tool_name_and_tool_input():
    agent = Agent(prompt_node=MockPromptNode())

    pred = "need to find out what city he was born.\nTool: Search\nTool Input: Where was Jeremy McKinnon born"
    tool_name, tool_input = agent._extract_tool_name_and_tool_input(pred)
    assert tool_name == "Search" and tool_input == "Where was Jeremy McKinnon born"


def test_extract_final_answer():
    agent = Agent(prompt_node=MockPromptNode())

    pred = "have the final answer to the question.\nFinal Answer: Florida"
    final_answer = agent._extract_final_answer(pred)
    assert final_answer == "Florida"


def test_format_answer():
    agent = Agent(prompt_node=MockPromptNode())
    formatted_answer = agent._format_answer(query="query", answer="answer", transcript="transcript")
    assert formatted_answer["query"] == "query"
    assert formatted_answer["answers"] == [Answer(answer="answer", type="generative")]
    assert formatted_answer["transcript"] == "transcript"


@pytest.mark.integration
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
@pytest.mark.parametrize("retriever_with_docs, document_store_with_docs", [("bm25", "memory")], indirect=True)
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_agent_run(reader, retriever_with_docs, document_store_with_docs):
    search = ExtractiveQAPipeline(reader, retriever_with_docs)
    prompt_model = PromptModel(model_name_or_path="text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"))
    prompt_node = PromptNode(model_name_or_path=prompt_model, stop_words=["Observation:"])
    counter = PromptNode(
        model_name_or_path=prompt_model,
        default_prompt_template=PromptTemplate(
            name="calculator_template",
            prompt_text="When I give you a word, respond with the number of characters that this word contains.\n"
            "Word: Rome\nLength: 4\n"
            "Word: Arles\nLength: 5\n"
            "Word: Berlin\nLength: 6\n"
            "Word: $query?\nLength: ",
            prompt_params=["query"],
        ),
    )

    agent = Agent(prompt_node=prompt_node)
    agent.add_tool(
        Tool(
            name="Search",
            pipeline_or_node=search,
            description="useful for when you need to answer questions about where people live. "
            "You should ask targeted questions",
        )
    )
    agent.add_tool(
        Tool(
            name="Count",
            pipeline_or_node=counter,
            description="useful for when you need to count how many characters are in a word. Ask only with a single word.",
        )
    )

    # TODO Replace Count tool once more tools are implemented so that we do not need to account for off-by-one errors
    result = agent.run("How many characters are in the word Madrid?")
    assert any(digit in result["answers"][0].answer for digit in ["5", "6", "five", "six"])

    result = agent.run("How many letters does the name of the town where Christelle lives have?")
    assert any(digit in result["answers"][0].answer for digit in ["5", "6", "five", "six"])


@pytest.mark.integration
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
@pytest.mark.parametrize("retriever_with_docs, document_store_with_docs", [("bm25", "memory")], indirect=True)
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_agent_run_batch(reader, retriever_with_docs, document_store_with_docs):
    search = ExtractiveQAPipeline(reader, retriever_with_docs)
    prompt_model = PromptModel(model_name_or_path="text-davinci-003", api_key=os.environ.get("OPENAI_API_KEY"))
    prompt_node = PromptNode(model_name_or_path=prompt_model, stop_words=["Observation:"])
    counter = PromptNode(
        model_name_or_path=prompt_model,
        default_prompt_template=PromptTemplate(
            name="calculator_template",
            prompt_text="When I give you a word, respond with the number of characters that this word contains.\n"
            "Word: Rome\nLength: 4\n"
            "Word: Arles\nLength: 5\n"
            "Word: Berlin\nLength: 6\n"
            "Word: $query?\nLength: ",
            prompt_params=["query"],
        ),
    )

    agent = Agent(prompt_node=prompt_node)
    agent.add_tool(
        Tool(
            name="Search",
            pipeline_or_node=search,
            description="useful for when you need to answer questions about where people live. "
            "You should ask targeted questions",
        )
    )
    agent.add_tool(
        Tool(
            name="Count",
            pipeline_or_node=counter,
            description="useful for when you need to count how many characters are in a word. Ask only with a single word.",
        )
    )

    results = agent.run_batch(
        queries=[
            "How many characters are in the word Madrid?",
            "How many letters does the name of the town where Christelle lives have?",
        ]
    )
    # TODO Replace Count tool once more tools are implemented so that we do not need to account for off-by-one errors
    assert any(digit in results["answers"][0][0].answer for digit in ["5", "6", "five", "six"])
    assert any(digit in results["answers"][1][0].answer for digit in ["5", "6", "five", "six"])
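For reference, `run_batch` returns parallel lists, so each entry of `results["answers"]` is itself a list of Answer objects, hence the double indexing in the assertions above. A sketch of the shape, with illustrative values:

# {
#     "queries": ["How many characters are in the word Madrid?", ...],
#     "answers": [[Answer(answer="6", type="generative")], ...],
#     "transcripts": ["...full prompt plus generated text...", ...],
# }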
test/conftest.py

@@ -55,10 +55,11 @@ from haystack.nodes import (
     TransformersSummarizer,
     TransformersTranslator,
     QuestionGenerator,
+    PromptTemplate,
 )
 from haystack.modeling.infer import Inferencer, QAInferencer
 from haystack.nodes.prompt import PromptNode, PromptModel
-from haystack.schema import Document
+from haystack.schema import Document, FilterType
 from haystack.utils.import_utils import _optional_component_not_installed

 try:

@@ -272,11 +273,30 @@ class MockDocumentStore(BaseDocumentStore):
 class MockRetriever(BaseRetriever):
     outgoing_edges = 1

-    def retrieve(self, query: str, top_k: int):
-        pass
+    def retrieve(
+        self,
+        query: str,
+        filters: Optional[FilterType] = None,
+        top_k: Optional[int] = None,
+        index: Optional[str] = None,
+        headers: Optional[Dict[str, str]] = None,
+        scale_score: Optional[bool] = None,
+        document_store: Optional[BaseDocumentStore] = None,
+    ) -> List[Document]:
+        return []

-    def retrieve_batch(self, queries: List[str], top_k: int):
-        pass
+    def retrieve_batch(
+        self,
+        queries: List[str],
+        filters: Optional[Union[FilterType, List[Optional[FilterType]]]] = None,
+        top_k: Optional[int] = None,
+        index: Optional[str] = None,
+        headers: Optional[Dict[str, str]] = None,
+        batch_size: Optional[int] = None,
+        scale_score: Optional[bool] = None,
+        document_store: Optional[BaseDocumentStore] = None,
+    ) -> List[List[Document]]:
+        return [[]]


 class MockSeq2SegGenerator(BaseGenerator):

@@ -343,6 +363,40 @@ class MockReader(BaseReader):
         pass


+class MockPromptNode(PromptNode):
+    def __init__(self):
+        self.default_prompt_template = None
+
+    def prompt(self, prompt_template: Optional[Union[str, PromptTemplate]], *args, **kwargs) -> List[str]:
+        return [""]
+
+    def get_prompt_template(self, prompt_template_name: str) -> PromptTemplate:
+        if prompt_template_name == "think-step-by-step":
+            return PromptTemplate(
+                name="think-step-by-step",
+                prompt_text="You are a helpful and knowledgeable agent. To achieve your goal of answering complex questions "
+                "correctly, you have access to the following tools:\n\n"
+                "$tool_names_with_descriptions\n\n"
+                "To answer questions, you'll need to go through multiple steps involving step-by-step thinking and "
+                "selecting appropriate tools and their inputs; tools will respond with observations. When you are ready "
+                "for a final answer, respond with the `Final Answer:`\n\n"
+                "Use the following format:\n\n"
+                "Question: the question to be answered\n"
+                "Thought: Reason if you have the final answer. If yes, answer the question. If not, find out the missing information needed to answer it.\n"
+                "Tool: [$tool_names]\n"
+                "Tool Input: the input for the tool\n"
+                "Observation: the tool will respond with the result\n"
+                "...\n"
+                "Final Answer: the final answer to the question, make it short (1-5 words)\n\n"
+                "Thought, Tool, Tool Input, and Observation steps can be repeated multiple times, but sometimes we can find an answer in the first pass\n"
+                "---\n\n"
+                "Question: $query\n"
+                "Thought: Let's think step-by-step, I first need to $generated_text",
+            )
+        else:
+            return PromptTemplate(name="", prompt_text="")


 #
 # Document collections
 #