mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-10-24 14:28:42 +00:00
refactor: Add AgentStep (#4431)
This commit is contained in:
parent
4d19bd13a5
commit
3272e2b9fe
@ -1,2 +1,3 @@
|
||||
from haystack.agents.agent_step import AgentStep
|
||||
from haystack.agents.base import Agent
|
||||
from haystack.agents.base import Tool
|
||||
|
||||
157
haystack/agents/agent_step.py
Normal file
157
haystack/agents/agent_step.py
Normal file
@ -0,0 +1,157 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, Dict, Tuple, Any
|
||||
|
||||
from haystack import Answer
|
||||
from haystack.errors import AgentError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentStep:
|
||||
"""
|
||||
The AgentStep class represents a single step in the execution of an agent.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
current_step: int = 1,
|
||||
max_steps: int = 10,
|
||||
final_answer_pattern: str = r"Final Answer\s*:\s*(.*)",
|
||||
prompt_node_response: str = "",
|
||||
transcript: str = "",
|
||||
):
|
||||
"""
|
||||
:param current_step: The current step in the execution of the agent.
|
||||
:param max_steps: The maximum number of steps the agent can execute.
|
||||
:param final_answer_pattern: The regex pattern to extract the final answer from the PromptNode response.
|
||||
:param prompt_node_response: The PromptNode response received.
|
||||
:param transcript: The full Agent execution transcript based on the Agent's initial prompt template and the
|
||||
text it generated during execution up to this step. The transcript is used to generate the next prompt.
|
||||
"""
|
||||
self.current_step = current_step
|
||||
self.max_steps = max_steps
|
||||
self.final_answer_pattern = final_answer_pattern
|
||||
self.prompt_node_response = prompt_node_response
|
||||
self.transcript = transcript
|
||||
|
||||
def prepare_prompt(self):
|
||||
"""
|
||||
Prepares the prompt for the next step.
|
||||
"""
|
||||
return self.transcript
|
||||
|
||||
def create_next_step(self, prompt_node_response: Any) -> AgentStep:
|
||||
"""
|
||||
Creates the next agent step based on the current step and the PromptNode response.
|
||||
:param prompt_node_response: The PromptNode response received.
|
||||
"""
|
||||
if not isinstance(prompt_node_response, list):
|
||||
raise AgentError(
|
||||
f"Agent output must be a list of str, but {prompt_node_response} received. "
|
||||
f"Transcript:\n{self.transcript}"
|
||||
)
|
||||
|
||||
if not prompt_node_response:
|
||||
raise AgentError(
|
||||
f"Agent output must be a non empty list of str, but {prompt_node_response} received. "
|
||||
f"Transcript:\n{self.transcript}"
|
||||
)
|
||||
|
||||
return AgentStep(
|
||||
current_step=self.current_step + 1,
|
||||
max_steps=self.max_steps,
|
||||
final_answer_pattern=self.final_answer_pattern,
|
||||
prompt_node_response=prompt_node_response[0],
|
||||
transcript=self.transcript,
|
||||
)
|
||||
|
||||
def extract_tool_name_and_tool_input(self, tool_pattern: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Parse the tool name and the tool input from the PromptNode response.
|
||||
:param tool_pattern: The regex pattern to extract the tool name and the tool input from the PromptNode response.
|
||||
:return: A tuple containing the tool name and the tool input.
|
||||
"""
|
||||
tool_match = re.search(tool_pattern, self.prompt_node_response)
|
||||
if tool_match:
|
||||
tool_name = tool_match.group(1)
|
||||
tool_input = tool_match.group(3)
|
||||
return tool_name.strip('" []\n').strip(), tool_input.strip('" \n')
|
||||
return None, None
|
||||
|
||||
def final_answer(self, query: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Formats an answer as a dict containing `query` and `answers` similar to the output of a Pipeline.
|
||||
The full transcript based on the Agent's initial prompt template and the text it generated during execution.
|
||||
|
||||
:param query: The search query
|
||||
"""
|
||||
answer: Dict[str, Any] = {
|
||||
"query": query,
|
||||
"answers": [Answer(answer="", type="generative")],
|
||||
"transcript": self.transcript,
|
||||
}
|
||||
if self.current_step >= self.max_steps:
|
||||
logger.warning(
|
||||
"Maximum number of iterations (%s) reached for query (%s). Increase max_steps "
|
||||
"or no answer can be provided for this query.",
|
||||
self.max_steps,
|
||||
query,
|
||||
)
|
||||
else:
|
||||
final_answer = self.extract_final_answer()
|
||||
if not final_answer:
|
||||
logger.warning(
|
||||
"Final answer pattern (%s) not found in PromptNode response (%s).",
|
||||
self.final_answer_pattern,
|
||||
self.prompt_node_response,
|
||||
)
|
||||
else:
|
||||
answer = {
|
||||
"query": query,
|
||||
"answers": [Answer(answer=final_answer, type="generative")],
|
||||
"transcript": self.transcript,
|
||||
}
|
||||
return answer
|
||||
|
||||
def extract_final_answer(self) -> Optional[str]:
|
||||
"""
|
||||
Parse the final answer from the PromptNode response.
|
||||
:return: The final answer.
|
||||
"""
|
||||
if not self.is_last():
|
||||
raise AgentError("Cannot extract final answer from non terminal step.")
|
||||
|
||||
final_answer_match = re.search(self.final_answer_pattern, self.prompt_node_response)
|
||||
if final_answer_match:
|
||||
final_answer = final_answer_match.group(1)
|
||||
return final_answer.strip('" ')
|
||||
return None
|
||||
|
||||
def is_final_answer_pattern_found(self) -> bool:
|
||||
"""
|
||||
Check if the final answer pattern was found in PromptNode response.
|
||||
:return: True if the final answer pattern was found in PromptNode response, False otherwise.
|
||||
"""
|
||||
return bool(re.search(self.final_answer_pattern, self.prompt_node_response))
|
||||
|
||||
def is_last(self) -> bool:
|
||||
"""
|
||||
Check if this is the last step of the Agent.
|
||||
:return: True if this is the last step of the Agent, False otherwise.
|
||||
"""
|
||||
return self.is_final_answer_pattern_found() or self.current_step >= self.max_steps
|
||||
|
||||
def completed(self, observation: Optional[str]):
|
||||
"""
|
||||
Update the transcript with the observation
|
||||
:param observation: received observation from the Agent environment.
|
||||
"""
|
||||
self.transcript += (
|
||||
f"{self.prompt_node_response}\nObservation: {observation}\nThought:"
|
||||
if observation
|
||||
else self.prompt_node_response
|
||||
)
|
||||
@ -2,9 +2,10 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Optional, Union, Dict, Tuple, Any
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
|
||||
from haystack import Pipeline, BaseComponent, Answer, Document
|
||||
from haystack.agents.agent_step import AgentStep
|
||||
from haystack.errors import AgentError
|
||||
from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
|
||||
from haystack.pipelines import (
|
||||
@ -24,10 +25,12 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class Tool:
|
||||
"""
|
||||
Agent uses tools to find the best answer. A tool is a pipeline or a node. When you add a tool to an Agent, the Agent can
|
||||
invoke the underlying pipeline or node to answer questions.
|
||||
Agent uses tools to find the best answer. A tool is a pipeline or a node. When you add a tool to an Agent, the Agent
|
||||
can invoke the underlying pipeline or node to answer questions.
|
||||
|
||||
You must provide a name and a description for each tool. The name should be short and should indicate what the tool can do. The description should explain what the tool is useful for. The Agent uses the description to decide when to use a tool, so the wording you use is important.
|
||||
You must provide a name and a description for each tool. The name should be short and should indicate what the tool
|
||||
can do. The description should explain what the tool is useful for. The Agent uses the description to decide when
|
||||
to use a tool, so the wording you use is important.
|
||||
|
||||
:param name: The name of the tool. The Agent uses this name to refer to the tool in the text the Agent generates.
|
||||
The name should be short, ideally one token, and a good description of what the tool can do, for example:
|
||||
@ -58,7 +61,8 @@ class Tool:
|
||||
):
|
||||
if re.search(r"\W", name):
|
||||
raise ValueError(
|
||||
f"Invalid name supplied for tool: '{name}'. Use only letters (a-z, A-Z), digits (0-9) and underscores (_)."
|
||||
f"Invalid name supplied for tool: '{name}'. Use only letters (a-z, A-Z), digits (0-9) and "
|
||||
f"underscores (_)."
|
||||
)
|
||||
self.name = name
|
||||
self.pipeline_or_node = pipeline_or_node
|
||||
@ -101,14 +105,16 @@ class Tool:
|
||||
class Agent:
|
||||
"""
|
||||
An Agent answers queries using the tools you give to it. The tools are pipelines or nodes. The Agent uses a large
|
||||
language model (LLM) through the PromptNode you initialize it with. To answer a query, the Agent follows this sequence:
|
||||
language model (LLM) through the PromptNode you initialize it with. To answer a query, the Agent follows this
|
||||
sequence:
|
||||
|
||||
1. It generates a thought based on the query.
|
||||
2. It decides which tool to use.
|
||||
3. It generates the input for the tool.
|
||||
4. Based on the output it gets from the tool, the Agent can either stop if it now knows the answer or repeat the
|
||||
process of 1) generate thought, 2) choose tool, 3) generate input.
|
||||
|
||||
Agents are useful for questions containing multiple subquestions that can be answered step-by-step (Multihop QA)
|
||||
Agents are useful for questions containing multiple sub questions that can be answered step-by-step (Multi-hop QA)
|
||||
using multiple pipelines and nodes as tools.
|
||||
"""
|
||||
|
||||
@ -117,19 +123,25 @@ class Agent:
|
||||
prompt_node: PromptNode,
|
||||
prompt_template: Union[str, PromptTemplate] = "zero-shot-react",
|
||||
tools: Optional[List[Tool]] = None,
|
||||
max_iterations: int = 5,
|
||||
max_steps: int = 8,
|
||||
tool_pattern: str = r'Tool:\s*(\w+)\s*Tool Input:\s*("?)([^"\n]+)\2\s*',
|
||||
final_answer_pattern: str = r"Final Answer:\s*(\w+)\s*",
|
||||
final_answer_pattern: str = r"Final Answer\s*:\s*(.*)",
|
||||
):
|
||||
"""
|
||||
Creates an Agent instance.
|
||||
|
||||
:param prompt_node: The PromptNode that the Agent uses to decide which tool to use and what input to provide to it in each iteration.
|
||||
:param prompt_template: The name of a PromptTemplate for the PromptNode. It's used for generating thoughts and choosing tools to answer queries step-by-step. You can use the default `zero-shot-react` template or create a new template in a similar format.
|
||||
:param tools: A list of tools the Agent can run. If you don't specify any tools here, you must add them with `add_tool()` before running the Agent.
|
||||
:param max_iterations: The number of times the Agent can run a tool +1 to let it infer it knows the final answer.
|
||||
Set it to at least 2, so that the Agent can run one a tool once and then infer it knows the final answer. The default is 5.
|
||||
:param tool_pattern: A regular expression to extract the name of the tool and the corresponding input from the text the Agent generated.
|
||||
:param prompt_node: The PromptNode that the Agent uses to decide which tool to use and what input to provide to
|
||||
it in each iteration.
|
||||
:param prompt_template: The name of a PromptTemplate for the PromptNode. It's used for generating thoughts and
|
||||
choosing tools to answer queries step-by-step. You can use the default `zero-shot-react` template or create a
|
||||
new template in a similar format.
|
||||
:param tools: A list of tools the Agent can run. If you don't specify any tools here, you must add them
|
||||
with `add_tool()` before running the Agent.
|
||||
:param max_steps: The number of times the Agent can run a tool +1 to let it infer it knows the final answer.
|
||||
Set it to at least 2, so that the Agent can run one a tool once and then infer it knows the final answer.
|
||||
The default is 5.
|
||||
:param tool_pattern: A regular expression to extract the name of the tool and the corresponding input from the
|
||||
text the Agent generated.
|
||||
:param final_answer_pattern: A regular expression to extract the final answer from the text the Agent generated.
|
||||
"""
|
||||
self.prompt_node = prompt_node
|
||||
@ -141,7 +153,7 @@ class Agent:
|
||||
self.tool_names_with_descriptions = "\n".join(
|
||||
[f"{tool.name}: {tool.description}" for tool in self.tools.values()]
|
||||
)
|
||||
self.max_iterations = max_iterations
|
||||
self.max_steps = max_steps
|
||||
self.tool_pattern = tool_pattern
|
||||
self.final_answer_pattern = final_answer_pattern
|
||||
send_custom_event(event=f"{type(self).__name__} initialized")
|
||||
@ -150,7 +162,8 @@ class Agent:
|
||||
"""
|
||||
Add a tool to the Agent. This also updates the PromptTemplate for the Agent's PromptNode with the tool name.
|
||||
|
||||
:param tool: The tool to add to the Agent. Any previously added tool with the same name will be overwritten. Example:
|
||||
:param tool: The tool to add to the Agent. Any previously added tool with the same name will be overwritten.
|
||||
Example:
|
||||
`agent.add_tool(
|
||||
Tool(
|
||||
name="Calculator",
|
||||
@ -174,15 +187,16 @@ class Agent:
|
||||
return tool_name in self.tools
|
||||
|
||||
def run(
|
||||
self, query: str, max_iterations: Optional[int] = None, params: Optional[dict] = None
|
||||
self, query: str, max_steps: Optional[int] = None, params: Optional[dict] = None
|
||||
) -> Dict[str, Union[str, List[Answer]]]:
|
||||
"""
|
||||
Runs the Agent given a query and optional parameters to pass on to the tools used. The result is in the
|
||||
same format as a pipeline's result: a dictionary with a key `answers` containing a list of answers.
|
||||
|
||||
:param query: The search query.
|
||||
:param max_iterations: The number of times the Agent can run a tool +1 to infer it knows the final answer.
|
||||
If you want to set it, make it at least 2 so that the Agent can run a tool once and then infer it knows the final answer.
|
||||
:param query: The search query
|
||||
:param max_steps: The number of times the Agent can run a tool +1 to infer it knows the final answer.
|
||||
If you want to set it, make it at least 2 so that the Agent can run a tool once and then infer it knows the
|
||||
final answer.
|
||||
:param params: A dictionary of parameters you want to pass to the tools that are pipelines.
|
||||
To pass a parameter to all nodes in those pipelines, use the format: `{"top_k": 10}`.
|
||||
To pass a parameter to targeted nodes in those pipelines, use the format:
|
||||
@ -191,49 +205,57 @@ class Agent:
|
||||
"""
|
||||
if not self.tools:
|
||||
raise AgentError(
|
||||
"An Agent needs tools to run. Add at least one tool using `add_tool()` or set the parameter `tools` when initializing the Agent."
|
||||
"An Agent needs tools to run. Add at least one tool using `add_tool()` or set the parameter `tools` "
|
||||
"when initializing the Agent."
|
||||
)
|
||||
if max_iterations is None:
|
||||
max_iterations = self.max_iterations
|
||||
if max_iterations < 2:
|
||||
if max_steps is None:
|
||||
max_steps = self.max_steps
|
||||
if max_steps < 2:
|
||||
raise AgentError(
|
||||
f"max_iterations must be at least 2 to let the Agent use a tool once and then infer it knows the final answer. It was set to {max_iterations}."
|
||||
f"max_steps must be at least 2 to let the Agent use a tool once and then infer it knows the final "
|
||||
f"answer. It was set to {max_steps}."
|
||||
)
|
||||
|
||||
agent_step = self._create_first_step(query, max_steps)
|
||||
while not agent_step.is_last():
|
||||
agent_step = self._step(agent_step, params)
|
||||
|
||||
return agent_step.final_answer(query=query)
|
||||
|
||||
def _create_first_step(self, query: str, max_steps: int = 10):
|
||||
transcript = self._get_initial_transcript(query=query)
|
||||
# Generate a thought with a plan what to do, choose a tool, generate input for it, and run it.
|
||||
# Repeat this until the final answer is found or the maximum number of iterations is reached.
|
||||
for _ in range(max_iterations):
|
||||
preds = self.prompt_node(transcript)
|
||||
if not preds:
|
||||
raise AgentError(f"The Agent generated no output. Transcript:\n{transcript}")
|
||||
|
||||
# Try to extract final answer or tool name and input from the generated text and update the transcript
|
||||
final_answer = self._extract_final_answer(pred=preds[0])
|
||||
if final_answer is not None:
|
||||
transcript += preds[0]
|
||||
return self._format_answer(query=query, transcript=transcript, answer=final_answer)
|
||||
tool_name, tool_input = self._extract_tool_name_and_tool_input(pred=preds[0])
|
||||
observation = self._run_tool(tool_name, tool_input, transcript + preds[0], params)
|
||||
transcript += f"{preds[0]}\nObservation: {observation}\nThought:"
|
||||
|
||||
logger.warning(
|
||||
"The Agent reached the maximum number of iterations (%s) for query (%s). Increase the max_iterations parameter "
|
||||
"or the Agent won't be able to provide an answer to this query.",
|
||||
max_iterations,
|
||||
query,
|
||||
return AgentStep(
|
||||
current_step=1,
|
||||
max_steps=max_steps,
|
||||
final_answer_pattern=self.final_answer_pattern,
|
||||
prompt_node_response="", # no LLM response for the first step
|
||||
transcript=transcript,
|
||||
)
|
||||
return self._format_answer(query=query, transcript=transcript, answer="")
|
||||
|
||||
def _step(self, current_step: AgentStep, params: Optional[dict] = None):
|
||||
# plan next step using the LLM
|
||||
prompt_node_response = self.prompt_node(current_step.prepare_prompt())
|
||||
|
||||
# from the LLM response, create the next step
|
||||
next_step = current_step.create_next_step(prompt_node_response)
|
||||
|
||||
# run the tool selected by the LLM
|
||||
observation = self._run_tool(next_step, params) if not next_step.is_last() else None
|
||||
|
||||
# update the next step with the observation
|
||||
next_step.completed(observation)
|
||||
return next_step
|
||||
|
||||
def run_batch(
|
||||
self, queries: List[str], max_iterations: Optional[int] = None, params: Optional[dict] = None
|
||||
self, queries: List[str], max_steps: Optional[int] = None, params: Optional[dict] = None
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Runs the Agent in a batch mode.
|
||||
|
||||
:param queries: List of search queries.
|
||||
:param max_iterations: The number of times the Agent can run a tool +1 to infer it knows the final answer.
|
||||
If you want to set it, make it at least 2 so that the Agent can run a tool once and then infer it knows the final answer.
|
||||
:param max_steps: The number of times the Agent can run a tool +1 to infer it knows the final answer.
|
||||
If you want to set it, make it at least 2 so that the Agent can run a tool once and then infer it knows
|
||||
the final answer.
|
||||
:param params: A dictionary of parameters you want to pass to the tools that are pipelines.
|
||||
To pass a parameter to all nodes in those pipelines, use the format: `{"top_k": 10}`.
|
||||
To pass a parameter to targeted nodes in those pipelines, use the format:
|
||||
@ -242,67 +264,30 @@ class Agent:
|
||||
"""
|
||||
results: Dict = {"queries": [], "answers": [], "transcripts": []}
|
||||
for query in queries:
|
||||
result = self.run(query=query, max_iterations=max_iterations, params=params)
|
||||
result = self.run(query=query, max_steps=max_steps, params=params)
|
||||
results["queries"].append(result["query"])
|
||||
results["answers"].append(result["answers"])
|
||||
results["transcripts"].append(result["transcript"])
|
||||
|
||||
return results
|
||||
|
||||
def _run_tool(
|
||||
self, tool_name: Optional[str], tool_input: Optional[str], transcript: str, params: Optional[dict] = None
|
||||
) -> str:
|
||||
def _run_tool(self, next_step: AgentStep, params: Optional[Dict[str, Any]] = None) -> str:
|
||||
tool_name, tool_input = next_step.extract_tool_name_and_tool_input(self.tool_pattern)
|
||||
if tool_name is None or tool_input is None:
|
||||
raise AgentError(
|
||||
f"Could not identify the next tool or input for that tool from Agent's output. "
|
||||
f"Adjust the Agent's param 'tool_pattern' or 'prompt_template'. \n"
|
||||
f"# 'tool_pattern' to identify next tool: {self.tool_pattern} \n"
|
||||
f"# Transcript:\n{transcript}"
|
||||
f"# Agent Step:\n{next_step}"
|
||||
)
|
||||
if not self.has_tool(tool_name):
|
||||
raise AgentError(
|
||||
f"The tool {tool_name} wasn't added to the Agent tools: {self.tools.keys()}."
|
||||
"Add the tool using `add_tool()` or include it in the parameter `tools` when initializing the Agent."
|
||||
f"Transcript:\n{transcript}"
|
||||
f"Agent Step::\n{next_step}"
|
||||
)
|
||||
return self.tools[tool_name].run(tool_input, params)
|
||||
|
||||
def _extract_tool_name_and_tool_input(self, pred: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Parse the tool name and the tool input from the prediction output of the Agent's PromptNode.
|
||||
|
||||
:param pred: Prediction output of the Agent's PromptNode from which to parse the tool and tool input.
|
||||
"""
|
||||
tool_match = re.search(self.tool_pattern, pred)
|
||||
if tool_match:
|
||||
tool_name = tool_match.group(1)
|
||||
tool_input = tool_match.group(3)
|
||||
return tool_name.strip('" []').strip(), tool_input.strip('" ')
|
||||
return None, None
|
||||
|
||||
def _extract_final_answer(self, pred: str) -> Optional[str]:
|
||||
"""
|
||||
Parse the final answer from the prediction output of the Agent's PromptNode.
|
||||
|
||||
:param pred: Prediction output of the Agent's PromptNode from which to parse the final answer.
|
||||
"""
|
||||
final_answer_match = re.search(self.final_answer_pattern, pred)
|
||||
if final_answer_match:
|
||||
final_answer = final_answer_match.group(1)
|
||||
return final_answer.strip('" ')
|
||||
return None
|
||||
|
||||
def _format_answer(self, query: str, answer: str, transcript: str) -> Dict[str, Union[str, List[Answer]]]:
|
||||
"""
|
||||
Formats an answer as a dict containing `query` and `answers`, similar to the output of a Pipeline.
|
||||
The full transcript based on the Agent's initial prompt template and the text it generated during execution.
|
||||
|
||||
:param query: The search query.
|
||||
:param answer: The final answer the Agent returned. An empty string corresponds to no answer.
|
||||
:param transcript: The text the Agent generated and the initial, filled template for debug purposes.
|
||||
"""
|
||||
return {"query": query, "answers": [Answer(answer=answer, type="generative")], "transcript": transcript}
|
||||
|
||||
def _get_initial_transcript(self, query: str):
|
||||
"""
|
||||
Fills the Agent's PromptTemplate with the query, tool names, and descriptions.
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from haystack import BaseComponent, Answer, Document
|
||||
from haystack.agents import Agent
|
||||
from haystack.agents import Agent, AgentStep
|
||||
from haystack.agents.base import Tool
|
||||
from haystack.errors import AgentError
|
||||
from haystack.nodes import PromptModel, PromptNode, PromptTemplate
|
||||
@ -71,9 +72,9 @@ def test_agent_chooses_no_action():
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_max_iterations(caplog, monkeypatch):
|
||||
# Run an Agent and stop because max_iterations is reached
|
||||
agent = Agent(prompt_node=MockPromptNode(), max_iterations=3)
|
||||
def test_max_steps(caplog, monkeypatch):
|
||||
# Run an Agent and stop because max_steps is reached
|
||||
agent = Agent(prompt_node=MockPromptNode(), max_steps=3)
|
||||
retriever = MockRetriever()
|
||||
agent.add_tool(
|
||||
Tool(
|
||||
@ -88,17 +89,17 @@ def test_max_iterations(caplog, monkeypatch):
|
||||
def mock_extract_tool_name_and_tool_input(self, pred: str) -> Tuple[str, str]:
|
||||
return "Retriever", ""
|
||||
|
||||
monkeypatch.setattr(Agent, "_extract_tool_name_and_tool_input", mock_extract_tool_name_and_tool_input)
|
||||
monkeypatch.setattr(AgentStep, "extract_tool_name_and_tool_input", mock_extract_tool_name_and_tool_input)
|
||||
|
||||
# Using max_iterations as specified in the Agent's init method
|
||||
# Using max_steps as specified in the Agent's init method
|
||||
with caplog.at_level(logging.WARN, logger="haystack.agents"):
|
||||
result = agent.run("Where does Christelle live?")
|
||||
assert result["answers"] == [Answer(answer="", type="generative")]
|
||||
assert "maximum number of iterations (3)" in caplog.text.lower()
|
||||
|
||||
# Setting max_iterations in the Agent's run method
|
||||
# Setting max_steps in the Agent's run method
|
||||
with caplog.at_level(logging.WARN, logger="haystack.agents"):
|
||||
result = agent.run("Where does Christelle live?", max_iterations=2)
|
||||
result = agent.run("Where does Christelle live?", max_steps=2)
|
||||
assert result["answers"] == [Answer(answer="", type="generative")]
|
||||
assert "maximum number of iterations (2)" in caplog.text.lower()
|
||||
|
||||
@ -115,35 +116,81 @@ def test_run_tool():
|
||||
output_variable="documents",
|
||||
)
|
||||
)
|
||||
result = agent._run_tool(tool_name="Retriever", tool_input="", transcript="")
|
||||
pn_response = "need to find out what city he was born.\nTool: Retriever\nTool Input: Where was Jeremy McKinnon born"
|
||||
|
||||
step = AgentStep(prompt_node_response=pn_response)
|
||||
result = agent._run_tool(step)
|
||||
assert result == "[]" # empty list of documents
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_extract_tool_name_and_tool_input():
|
||||
agent = Agent(prompt_node=MockPromptNode())
|
||||
tool_pattern: str = r'Tool:\s*(\w+)\s*Tool Input:\s*("?)([^"\n]+)\2\s*'
|
||||
pn_response = "need to find out what city he was born.\nTool: Search\nTool Input: Where was Jeremy McKinnon born"
|
||||
|
||||
pred = "need to find out what city he was born.\nTool: Search\nTool Input: Where was Jeremy McKinnon born"
|
||||
tool_name, tool_input = agent._extract_tool_name_and_tool_input(pred)
|
||||
step = AgentStep(prompt_node_response=pn_response)
|
||||
tool_name, tool_input = step.extract_tool_name_and_tool_input(tool_pattern=tool_pattern)
|
||||
assert tool_name == "Search" and tool_input == "Where was Jeremy McKinnon born"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_extract_final_answer():
|
||||
agent = Agent(prompt_node=MockPromptNode())
|
||||
match_examples = [
|
||||
"have the final answer to the question.\nFinal Answer: Florida",
|
||||
"Final Answer: 42 is the answer",
|
||||
"Final Answer: 1234",
|
||||
"Final Answer: Answer",
|
||||
"Final Answer: This list: one and two and three",
|
||||
"Final Answer:42",
|
||||
"Final Answer: ",
|
||||
"Final Answer: The answer is 99 ",
|
||||
]
|
||||
expected_answers = [
|
||||
"Florida",
|
||||
"42 is the answer",
|
||||
"1234",
|
||||
"Answer",
|
||||
"This list: one and two and three",
|
||||
"42",
|
||||
"",
|
||||
"The answer is 99",
|
||||
]
|
||||
|
||||
pred = "have the final answer to the question.\nFinal Answer: Florida"
|
||||
final_answer = agent._extract_final_answer(pred)
|
||||
assert final_answer == "Florida"
|
||||
for example, expected_answer in zip(match_examples, expected_answers):
|
||||
agent_step = AgentStep(prompt_node_response=example)
|
||||
final_answer = agent_step.extract_final_answer()
|
||||
assert final_answer == expected_answer
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_format_answer():
|
||||
agent = Agent(prompt_node=MockPromptNode())
|
||||
formatted_answer = agent._format_answer(query="query", answer="answer", transcript="transcript")
|
||||
step = AgentStep(prompt_node_response="have the final answer to the question.\nFinal Answer: Florida")
|
||||
formatted_answer = step.final_answer(query="query")
|
||||
assert formatted_answer["query"] == "query"
|
||||
assert formatted_answer["answers"] == [Answer(answer="answer", type="generative")]
|
||||
assert formatted_answer["transcript"] == "transcript"
|
||||
assert formatted_answer["answers"] == [Answer(answer="Florida", type="generative")]
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_final_answer_regex():
|
||||
match_examples = [
|
||||
"Final Answer: 42 is the answer",
|
||||
"Final Answer: 1234",
|
||||
"Final Answer: Answer",
|
||||
"Final Answer: This list: one and two and three",
|
||||
"Final Answer:42",
|
||||
"Final Answer: ",
|
||||
"Final Answer: The answer is 99 ",
|
||||
]
|
||||
|
||||
non_match_examples = ["Final answer: 42 is the answer", "Final Answer", "The final answer is: 100"]
|
||||
final_answer_pattern = r"Final Answer\s*:\s*(.*)"
|
||||
for example in match_examples:
|
||||
match = re.match(final_answer_pattern, example)
|
||||
assert match is not None
|
||||
|
||||
for example in non_match_examples:
|
||||
match = re.match(final_answer_pattern, example)
|
||||
assert match is None
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user