refactor: Add AgentStep (#4431)

Vladimir Blagojevic 2023-03-17 18:21:14 +01:00 committed by GitHub
parent 4d19bd13a5
commit 3272e2b9fe
4 changed files with 303 additions and 113 deletions

View File haystack/agents/__init__.py

@@ -1,2 +1,3 @@
from haystack.agents.agent_step import AgentStep
from haystack.agents.base import Agent
from haystack.agents.base import Tool

View File haystack/agents/agent_step.py

@@ -0,0 +1,157 @@
from __future__ import annotations
import logging
import re
from typing import Optional, Dict, Tuple, Any
from haystack import Answer
from haystack.errors import AgentError
logger = logging.getLogger(__name__)
class AgentStep:
"""
The AgentStep class represents a single step in the execution of an agent.
"""
def __init__(
self,
current_step: int = 1,
max_steps: int = 10,
final_answer_pattern: str = r"Final Answer\s*:\s*(.*)",
prompt_node_response: str = "",
transcript: str = "",
):
"""
:param current_step: The current step in the execution of the agent.
:param max_steps: The maximum number of steps the agent can execute.
:param final_answer_pattern: The regex pattern to extract the final answer from the PromptNode response.
:param prompt_node_response: The PromptNode response received.
:param transcript: The full Agent execution transcript based on the Agent's initial prompt template and the
text it generated during execution up to this step. The transcript is used to generate the next prompt.
"""
self.current_step = current_step
self.max_steps = max_steps
self.final_answer_pattern = final_answer_pattern
self.prompt_node_response = prompt_node_response
self.transcript = transcript
def prepare_prompt(self):
"""
Prepares the prompt for the next step.
"""
return self.transcript
def create_next_step(self, prompt_node_response: Any) -> AgentStep:
"""
Creates the next agent step based on the current step and the PromptNode response.
:param prompt_node_response: The PromptNode response received.
"""
if not isinstance(prompt_node_response, list):
raise AgentError(
f"Agent output must be a list of str, but {prompt_node_response} received. "
f"Transcript:\n{self.transcript}"
)
if not prompt_node_response:
raise AgentError(
f"Agent output must be a non empty list of str, but {prompt_node_response} received. "
f"Transcript:\n{self.transcript}"
)
return AgentStep(
current_step=self.current_step + 1,
max_steps=self.max_steps,
final_answer_pattern=self.final_answer_pattern,
prompt_node_response=prompt_node_response[0],
transcript=self.transcript,
)
def extract_tool_name_and_tool_input(self, tool_pattern: str) -> Tuple[Optional[str], Optional[str]]:
"""
Parse the tool name and the tool input from the PromptNode response.
:param tool_pattern: The regex pattern to extract the tool name and the tool input from the PromptNode response.
:return: A tuple containing the tool name and the tool input.
"""
tool_match = re.search(tool_pattern, self.prompt_node_response)
if tool_match:
tool_name = tool_match.group(1)
tool_input = tool_match.group(3)
return tool_name.strip('" []\n').strip(), tool_input.strip('" \n')
return None, None
def final_answer(self, query: str) -> Dict[str, Any]:
"""
Formats an answer as a dict containing `query` and `answers`, similar to the output of a Pipeline.
The dict also includes the full transcript, based on the Agent's initial prompt template and the text the Agent
generated during execution.
:param query: The search query.
"""
answer: Dict[str, Any] = {
"query": query,
"answers": [Answer(answer="", type="generative")],
"transcript": self.transcript,
}
if self.current_step >= self.max_steps:
logger.warning(
"Maximum number of iterations (%s) reached for query (%s). Increase max_steps "
"or no answer can be provided for this query.",
self.max_steps,
query,
)
else:
final_answer = self.extract_final_answer()
if not final_answer:
logger.warning(
"Final answer pattern (%s) not found in PromptNode response (%s).",
self.final_answer_pattern,
self.prompt_node_response,
)
else:
answer = {
"query": query,
"answers": [Answer(answer=final_answer, type="generative")],
"transcript": self.transcript,
}
return answer
def extract_final_answer(self) -> Optional[str]:
"""
Parse the final answer from the PromptNode response.
:return: The final answer.
"""
if not self.is_last():
raise AgentError("Cannot extract final answer from non terminal step.")
final_answer_match = re.search(self.final_answer_pattern, self.prompt_node_response)
if final_answer_match:
final_answer = final_answer_match.group(1)
return final_answer.strip('" ')
return None
def is_final_answer_pattern_found(self) -> bool:
"""
Check if the final answer pattern was found in PromptNode response.
:return: True if the final answer pattern was found in PromptNode response, False otherwise.
"""
return bool(re.search(self.final_answer_pattern, self.prompt_node_response))
def is_last(self) -> bool:
"""
Check if this is the last step of the Agent.
:return: True if this is the last step of the Agent, False otherwise.
"""
return self.is_final_answer_pattern_found() or self.current_step >= self.max_steps
def completed(self, observation: Optional[str]):
"""
Updates the transcript with the observation.
:param observation: The observation received from the Agent's environment.
"""
self.transcript += (
f"{self.prompt_node_response}\nObservation: {observation}\nThought:"
if observation
else self.prompt_node_response
)
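For orientation, here is a minimal usage sketch of the new class (illustrative only, not part of the commit). The response strings and the query are invented; the two regex patterns are the defaults shown in this diff.

from haystack.agents.agent_step import AgentStep

tool_pattern = r'Tool:\s*(\w+)\s*Tool Input:\s*("?)([^"\n]+)\2\s*'

# A PromptNode response in which the LLM picks a tool:
step = AgentStep(prompt_node_response="I should look this up.\nTool: Search\nTool Input: Where was Jeremy McKinnon born")
print(step.extract_tool_name_and_tool_input(tool_pattern))  # ('Search', 'Where was Jeremy McKinnon born')
print(step.is_last())  # False: no final answer yet and current_step < max_steps

# After running the tool, the observation is folded back into the transcript:
step.completed("Gainesville, Florida")  # appends "...\nObservation: Gainesville, Florida\nThought:"

# A follow-up response in which the LLM gives the final answer:
next_step = step.create_next_step(["I now know the answer.\nFinal Answer: Florida"])
print(next_step.is_last())  # True: the final answer pattern matched
print(next_step.final_answer(query="Where was Jeremy McKinnon born?")["answers"][0].answer)  # Florida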

View File haystack/agents/base.py

@@ -2,9 +2,10 @@ from __future__ import annotations
import logging
import re
from typing import List, Optional, Union, Dict, Tuple, Any
from typing import List, Optional, Union, Dict, Any
from haystack import Pipeline, BaseComponent, Answer, Document
from haystack.agents.agent_step import AgentStep
from haystack.errors import AgentError
from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
from haystack.pipelines import (
@@ -24,10 +25,12 @@ logger = logging.getLogger(__name__)
class Tool:
"""
Agent uses tools to find the best answer. A tool is a pipeline or a node. When you add a tool to an Agent, the Agent can
invoke the underlying pipeline or node to answer questions.
Agent uses tools to find the best answer. A tool is a pipeline or a node. When you add a tool to an Agent, the Agent
can invoke the underlying pipeline or node to answer questions.
You must provide a name and a description for each tool. The name should be short and should indicate what the tool can do. The description should explain what the tool is useful for. The Agent uses the description to decide when to use a tool, so the wording you use is important.
You must provide a name and a description for each tool. The name should be short and should indicate what the tool
can do. The description should explain what the tool is useful for. The Agent uses the description to decide when
to use a tool, so the wording you use is important.
:param name: The name of the tool. The Agent uses this name to refer to the tool in the text the Agent generates.
The name should be short, ideally one token, and a good description of what the tool can do, for example:
@@ -58,7 +61,8 @@ class Tool:
):
if re.search(r"\W", name):
raise ValueError(
f"Invalid name supplied for tool: '{name}'. Use only letters (a-z, A-Z), digits (0-9) and underscores (_)."
f"Invalid name supplied for tool: '{name}'. Use only letters (a-z, A-Z), digits (0-9) and "
f"underscores (_)."
)
self.name = name
self.pipeline_or_node = pipeline_or_node
@@ -101,14 +105,16 @@ class Tool:
class Agent:
"""
An Agent answers queries using the tools you give to it. The tools are pipelines or nodes. The Agent uses a large
language model (LLM) through the PromptNode you initialize it with. To answer a query, the Agent follows this sequence:
language model (LLM) through the PromptNode you initialize it with. To answer a query, the Agent follows this
sequence:
1. It generates a thought based on the query.
2. It decides which tool to use.
3. It generates the input for the tool.
4. Based on the output it gets from the tool, the Agent can either stop if it now knows the answer or repeat the
process of 1) generate thought, 2) choose tool, 3) generate input.
Agents are useful for questions containing multiple subquestions that can be answered step-by-step (Multihop QA)
Agents are useful for questions containing multiple sub-questions that can be answered step-by-step (Multi-hop QA)
using multiple pipelines and nodes as tools.
"""
@@ -117,19 +123,25 @@ class Agent:
prompt_node: PromptNode,
prompt_template: Union[str, PromptTemplate] = "zero-shot-react",
tools: Optional[List[Tool]] = None,
max_iterations: int = 5,
max_steps: int = 8,
tool_pattern: str = r'Tool:\s*(\w+)\s*Tool Input:\s*("?)([^"\n]+)\2\s*',
final_answer_pattern: str = r"Final Answer:\s*(\w+)\s*",
final_answer_pattern: str = r"Final Answer\s*:\s*(.*)",
):
"""
Creates an Agent instance.
:param prompt_node: The PromptNode that the Agent uses to decide which tool to use and what input to provide to it in each iteration.
:param prompt_template: The name of a PromptTemplate for the PromptNode. It's used for generating thoughts and choosing tools to answer queries step-by-step. You can use the default `zero-shot-react` template or create a new template in a similar format.
:param tools: A list of tools the Agent can run. If you don't specify any tools here, you must add them with `add_tool()` before running the Agent.
:param max_iterations: The number of times the Agent can run a tool +1 to let it infer it knows the final answer.
Set it to at least 2, so that the Agent can run one a tool once and then infer it knows the final answer. The default is 5.
:param tool_pattern: A regular expression to extract the name of the tool and the corresponding input from the text the Agent generated.
:param prompt_node: The PromptNode that the Agent uses to decide which tool to use and what input to provide to
it in each iteration.
:param prompt_template: The name of a PromptTemplate for the PromptNode. It's used for generating thoughts and
choosing tools to answer queries step-by-step. You can use the default `zero-shot-react` template or create a
new template in a similar format.
:param tools: A list of tools the Agent can run. If you don't specify any tools here, you must add them
with `add_tool()` before running the Agent.
:param max_steps: The number of times the Agent can run a tool +1 to let it infer it knows the final answer.
Set it to at least 2, so that the Agent can run a tool once and then infer it knows the final answer.
The default is 8.
:param tool_pattern: A regular expression to extract the name of the tool and the corresponding input from the
text the Agent generated.
:param final_answer_pattern: A regular expression to extract the final answer from the text the Agent generated.
"""
self.prompt_node = prompt_node
@@ -141,7 +153,7 @@ class Agent:
self.tool_names_with_descriptions = "\n".join(
[f"{tool.name}: {tool.description}" for tool in self.tools.values()]
)
self.max_iterations = max_iterations
self.max_steps = max_steps
self.tool_pattern = tool_pattern
self.final_answer_pattern = final_answer_pattern
send_custom_event(event=f"{type(self).__name__} initialized")
@@ -150,7 +162,8 @@ class Agent:
"""
Add a tool to the Agent. This also updates the PromptTemplate for the Agent's PromptNode with the tool name.
:param tool: The tool to add to the Agent. Any previously added tool with the same name will be overwritten. Example:
:param tool: The tool to add to the Agent. Any previously added tool with the same name will be overwritten.
Example:
`agent.add_tool(
Tool(
name="Calculator",
@@ -174,15 +187,16 @@ class Agent:
return tool_name in self.tools
def run(
self, query: str, max_iterations: Optional[int] = None, params: Optional[dict] = None
self, query: str, max_steps: Optional[int] = None, params: Optional[dict] = None
) -> Dict[str, Union[str, List[Answer]]]:
"""
Runs the Agent given a query and optional parameters to pass on to the tools used. The result is in the
same format as a pipeline's result: a dictionary with a key `answers` containing a list of answers.
:param query: The search query.
:param max_iterations: The number of times the Agent can run a tool +1 to infer it knows the final answer.
If you want to set it, make it at least 2 so that the Agent can run a tool once and then infer it knows the final answer.
:param query: The search query
:param max_steps: The number of times the Agent can run a tool +1 to infer it knows the final answer.
If you want to set it, make it at least 2 so that the Agent can run a tool once and then infer it knows the
final answer.
:param params: A dictionary of parameters you want to pass to the tools that are pipelines.
To pass a parameter to all nodes in those pipelines, use the format: `{"top_k": 10}`.
To pass a parameter to targeted nodes in those pipelines, use the format:
@@ -191,49 +205,57 @@ class Agent:
"""
if not self.tools:
raise AgentError(
"An Agent needs tools to run. Add at least one tool using `add_tool()` or set the parameter `tools` when initializing the Agent."
"An Agent needs tools to run. Add at least one tool using `add_tool()` or set the parameter `tools` "
"when initializing the Agent."
)
if max_iterations is None:
max_iterations = self.max_iterations
if max_iterations < 2:
if max_steps is None:
max_steps = self.max_steps
if max_steps < 2:
raise AgentError(
f"max_iterations must be at least 2 to let the Agent use a tool once and then infer it knows the final answer. It was set to {max_iterations}."
f"max_steps must be at least 2 to let the Agent use a tool once and then infer it knows the final "
f"answer. It was set to {max_steps}."
)
agent_step = self._create_first_step(query, max_steps)
while not agent_step.is_last():
agent_step = self._step(agent_step, params)
return agent_step.final_answer(query=query)
def _create_first_step(self, query: str, max_steps: int = 10):
transcript = self._get_initial_transcript(query=query)
# Generate a thought with a plan what to do, choose a tool, generate input for it, and run it.
# Repeat this until the final answer is found or the maximum number of iterations is reached.
for _ in range(max_iterations):
preds = self.prompt_node(transcript)
if not preds:
raise AgentError(f"The Agent generated no output. Transcript:\n{transcript}")
# Try to extract final answer or tool name and input from the generated text and update the transcript
final_answer = self._extract_final_answer(pred=preds[0])
if final_answer is not None:
transcript += preds[0]
return self._format_answer(query=query, transcript=transcript, answer=final_answer)
tool_name, tool_input = self._extract_tool_name_and_tool_input(pred=preds[0])
observation = self._run_tool(tool_name, tool_input, transcript + preds[0], params)
transcript += f"{preds[0]}\nObservation: {observation}\nThought:"
logger.warning(
"The Agent reached the maximum number of iterations (%s) for query (%s). Increase the max_iterations parameter "
"or the Agent won't be able to provide an answer to this query.",
max_iterations,
query,
return AgentStep(
current_step=1,
max_steps=max_steps,
final_answer_pattern=self.final_answer_pattern,
prompt_node_response="", # no LLM response for the first step
transcript=transcript,
)
return self._format_answer(query=query, transcript=transcript, answer="")
def _step(self, current_step: AgentStep, params: Optional[dict] = None):
# plan next step using the LLM
prompt_node_response = self.prompt_node(current_step.prepare_prompt())
# from the LLM response, create the next step
next_step = current_step.create_next_step(prompt_node_response)
# run the tool selected by the LLM
observation = self._run_tool(next_step, params) if not next_step.is_last() else None
# update the next step with the observation
next_step.completed(observation)
return next_step
def run_batch(
self, queries: List[str], max_iterations: Optional[int] = None, params: Optional[dict] = None
self, queries: List[str], max_steps: Optional[int] = None, params: Optional[dict] = None
) -> Dict[str, str]:
"""
Runs the Agent in a batch mode.
:param queries: List of search queries.
:param max_iterations: The number of times the Agent can run a tool +1 to infer it knows the final answer.
If you want to set it, make it at least 2 so that the Agent can run a tool once and then infer it knows the final answer.
:param max_steps: The number of times the Agent can run a tool +1 to infer it knows the final answer.
If you want to set it, make it at least 2 so that the Agent can run a tool once and then infer it knows
the final answer.
:param params: A dictionary of parameters you want to pass to the tools that are pipelines.
To pass a parameter to all nodes in those pipelines, use the format: `{"top_k": 10}`.
To pass a parameter to targeted nodes in those pipelines, use the format:
@@ -242,67 +264,30 @@ class Agent:
"""
results: Dict = {"queries": [], "answers": [], "transcripts": []}
for query in queries:
result = self.run(query=query, max_iterations=max_iterations, params=params)
result = self.run(query=query, max_steps=max_steps, params=params)
results["queries"].append(result["query"])
results["answers"].append(result["answers"])
results["transcripts"].append(result["transcript"])
return results
def _run_tool(
self, tool_name: Optional[str], tool_input: Optional[str], transcript: str, params: Optional[dict] = None
) -> str:
def _run_tool(self, next_step: AgentStep, params: Optional[Dict[str, Any]] = None) -> str:
tool_name, tool_input = next_step.extract_tool_name_and_tool_input(self.tool_pattern)
if tool_name is None or tool_input is None:
raise AgentError(
f"Could not identify the next tool or input for that tool from Agent's output. "
f"Adjust the Agent's param 'tool_pattern' or 'prompt_template'. \n"
f"# 'tool_pattern' to identify next tool: {self.tool_pattern} \n"
f"# Transcript:\n{transcript}"
f"# Agent Step:\n{next_step}"
)
if not self.has_tool(tool_name):
raise AgentError(
f"The tool {tool_name} wasn't added to the Agent tools: {self.tools.keys()}."
"Add the tool using `add_tool()` or include it in the parameter `tools` when initializing the Agent."
f"Transcript:\n{transcript}"
f"Agent Step::\n{next_step}"
)
return self.tools[tool_name].run(tool_input, params)
def _extract_tool_name_and_tool_input(self, pred: str) -> Tuple[Optional[str], Optional[str]]:
"""
Parse the tool name and the tool input from the prediction output of the Agent's PromptNode.
:param pred: Prediction output of the Agent's PromptNode from which to parse the tool and tool input.
"""
tool_match = re.search(self.tool_pattern, pred)
if tool_match:
tool_name = tool_match.group(1)
tool_input = tool_match.group(3)
return tool_name.strip('" []').strip(), tool_input.strip('" ')
return None, None
def _extract_final_answer(self, pred: str) -> Optional[str]:
"""
Parse the final answer from the prediction output of the Agent's PromptNode.
:param pred: Prediction output of the Agent's PromptNode from which to parse the final answer.
"""
final_answer_match = re.search(self.final_answer_pattern, pred)
if final_answer_match:
final_answer = final_answer_match.group(1)
return final_answer.strip('" ')
return None
def _format_answer(self, query: str, answer: str, transcript: str) -> Dict[str, Union[str, List[Answer]]]:
"""
Formats an answer as a dict containing `query` and `answers`, similar to the output of a Pipeline.
The full transcript based on the Agent's initial prompt template and the text it generated during execution.
:param query: The search query.
:param answer: The final answer the Agent returned. An empty string corresponds to no answer.
:param transcript: The text the Agent generated and the initial, filled template for debug purposes.
"""
return {"query": query, "answers": [Answer(answer=answer, type="generative")], "transcript": transcript}
def _get_initial_transcript(self, query: str):
"""
Fills the Agent's PromptTemplate with the query, tool names, and descriptions.
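A hedged wiring sketch for the refactored run loop follows (illustrative only, not part of the commit). The model name is an assumption, and the second PromptNode merely stands in for whatever pipeline or node you would normally register as a tool, to keep the snippet self-contained.

from haystack.agents import Agent
from haystack.agents.base import Tool
from haystack.nodes import PromptNode

prompt_node = PromptNode("google/flan-t5-base")  # placeholder model; the Agent plans its steps with this LLM
calculator_node = PromptNode("google/flan-t5-base")  # placeholder tool backend; any pipeline or node works here

agent = Agent(prompt_node=prompt_node, max_steps=8)  # max_steps replaces the old max_iterations parameter
agent.add_tool(
    Tool(
        name="Calculator",
        pipeline_or_node=calculator_node,
        description="useful for when you need to answer questions about math",
    )
)

# Internally, run() now delegates the loop to AgentStep:
#   agent_step = self._create_first_step(query, max_steps)
#   while not agent_step.is_last():
#       agent_step = self._step(agent_step, params)
result = agent.run("What is 2 to the power of 10?", max_steps=4)
print(result["answers"][0].answer)
print(result["transcript"])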

View File test/agents/test_agent.py

@@ -1,11 +1,12 @@
import logging
import os
import re
from typing import Tuple
import pytest
from haystack import BaseComponent, Answer, Document
from haystack.agents import Agent
from haystack.agents import Agent, AgentStep
from haystack.agents.base import Tool
from haystack.errors import AgentError
from haystack.nodes import PromptModel, PromptNode, PromptTemplate
@@ -71,9 +72,9 @@ def test_agent_chooses_no_action():
@pytest.mark.unit
def test_max_iterations(caplog, monkeypatch):
# Run an Agent and stop because max_iterations is reached
agent = Agent(prompt_node=MockPromptNode(), max_iterations=3)
def test_max_steps(caplog, monkeypatch):
# Run an Agent and stop because max_steps is reached
agent = Agent(prompt_node=MockPromptNode(), max_steps=3)
retriever = MockRetriever()
agent.add_tool(
Tool(
@@ -88,17 +89,17 @@ def test_max_iterations(caplog, monkeypatch):
def mock_extract_tool_name_and_tool_input(self, pred: str) -> Tuple[str, str]:
return "Retriever", ""
monkeypatch.setattr(Agent, "_extract_tool_name_and_tool_input", mock_extract_tool_name_and_tool_input)
monkeypatch.setattr(AgentStep, "extract_tool_name_and_tool_input", mock_extract_tool_name_and_tool_input)
# Using max_iterations as specified in the Agent's init method
# Using max_steps as specified in the Agent's init method
with caplog.at_level(logging.WARN, logger="haystack.agents"):
result = agent.run("Where does Christelle live?")
assert result["answers"] == [Answer(answer="", type="generative")]
assert "maximum number of iterations (3)" in caplog.text.lower()
# Setting max_iterations in the Agent's run method
# Setting max_steps in the Agent's run method
with caplog.at_level(logging.WARN, logger="haystack.agents"):
result = agent.run("Where does Christelle live?", max_iterations=2)
result = agent.run("Where does Christelle live?", max_steps=2)
assert result["answers"] == [Answer(answer="", type="generative")]
assert "maximum number of iterations (2)" in caplog.text.lower()
@@ -115,35 +116,81 @@ def test_run_tool():
output_variable="documents",
)
)
result = agent._run_tool(tool_name="Retriever", tool_input="", transcript="")
pn_response = "need to find out what city he was born.\nTool: Retriever\nTool Input: Where was Jeremy McKinnon born"
step = AgentStep(prompt_node_response=pn_response)
result = agent._run_tool(step)
assert result == "[]" # empty list of documents
@pytest.mark.unit
def test_extract_tool_name_and_tool_input():
agent = Agent(prompt_node=MockPromptNode())
tool_pattern: str = r'Tool:\s*(\w+)\s*Tool Input:\s*("?)([^"\n]+)\2\s*'
pn_response = "need to find out what city he was born.\nTool: Search\nTool Input: Where was Jeremy McKinnon born"
pred = "need to find out what city he was born.\nTool: Search\nTool Input: Where was Jeremy McKinnon born"
tool_name, tool_input = agent._extract_tool_name_and_tool_input(pred)
step = AgentStep(prompt_node_response=pn_response)
tool_name, tool_input = step.extract_tool_name_and_tool_input(tool_pattern=tool_pattern)
assert tool_name == "Search" and tool_input == "Where was Jeremy McKinnon born"
@pytest.mark.unit
def test_extract_final_answer():
agent = Agent(prompt_node=MockPromptNode())
match_examples = [
"have the final answer to the question.\nFinal Answer: Florida",
"Final Answer: 42 is the answer",
"Final Answer: 1234",
"Final Answer: Answer",
"Final Answer: This list: one and two and three",
"Final Answer:42",
"Final Answer: ",
"Final Answer: The answer is 99 ",
]
expected_answers = [
"Florida",
"42 is the answer",
"1234",
"Answer",
"This list: one and two and three",
"42",
"",
"The answer is 99",
]
pred = "have the final answer to the question.\nFinal Answer: Florida"
final_answer = agent._extract_final_answer(pred)
assert final_answer == "Florida"
for example, expected_answer in zip(match_examples, expected_answers):
agent_step = AgentStep(prompt_node_response=example)
final_answer = agent_step.extract_final_answer()
assert final_answer == expected_answer
@pytest.mark.unit
def test_format_answer():
agent = Agent(prompt_node=MockPromptNode())
formatted_answer = agent._format_answer(query="query", answer="answer", transcript="transcript")
step = AgentStep(prompt_node_response="have the final answer to the question.\nFinal Answer: Florida")
formatted_answer = step.final_answer(query="query")
assert formatted_answer["query"] == "query"
assert formatted_answer["answers"] == [Answer(answer="answer", type="generative")]
assert formatted_answer["transcript"] == "transcript"
assert formatted_answer["answers"] == [Answer(answer="Florida", type="generative")]
@pytest.mark.unit
def test_final_answer_regex():
match_examples = [
"Final Answer: 42 is the answer",
"Final Answer: 1234",
"Final Answer: Answer",
"Final Answer: This list: one and two and three",
"Final Answer:42",
"Final Answer: ",
"Final Answer: The answer is 99 ",
]
non_match_examples = ["Final answer: 42 is the answer", "Final Answer", "The final answer is: 100"]
final_answer_pattern = r"Final Answer\s*:\s*(.*)"
for example in match_examples:
match = re.match(final_answer_pattern, example)
assert match is not None
for example in non_match_examples:
match = re.match(final_answer_pattern, example)
assert match is None
@pytest.mark.integration
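A side note on the changed default `final_answer_pattern` (from `Final Answer:\s*(\w+)\s*` to `Final Answer\s*:\s*(.*)`): the old pattern kept only the first word of a multi-word answer, which is exactly what the new multi-word test cases above guard against. A small illustration, not part of the commit:

import re

old_pattern = r"Final Answer:\s*(\w+)\s*"   # previous default
new_pattern = r"Final Answer\s*:\s*(.*)"    # default introduced in this commit

response = "I have the final answer.\nFinal Answer: 42 is the answer"
print(re.search(old_pattern, response).group(1))  # '42' -- only the first word survives
print(re.search(new_pattern, response).group(1))  # '42 is the answer'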