refactor: Make extraction of "Tool" and "Tool input" for Agent more robust and user-friendly (#4269)

* adjust [] in prompt template. Add error+docs for Tool name.

* fix test

* update error message
This commit is contained in:
Malte Pietsch 2023-02-28 20:01:34 +01:00 committed by GitHub
parent c3a38a59c0
commit 2a1d73e16d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 15 additions and 4 deletions

View File

@ -30,7 +30,7 @@ class Tool:
:param name: The name of the tool. The Agent uses this name to refer to the tool in the text the Agent generates.
The name should be short, ideally one token, and a good description of what the tool can do, for example
"Calculator" or "Search".
"Calculator" or "Search". Use only letters (a-z, A-Z), digits (0-9) and underscores (_).".
:param pipeline_or_node: The pipeline or node to run when this tool is invoked by an Agent.
:param description: A description of what the tool is useful for. An Agent can use this description for the decision
when to use which tool. For example, a tool for calculations can be described by "useful for when you need to
@ -53,6 +53,10 @@ class Tool:
],
description: str,
):
if re.search(r"\W", name):
raise ValueError(
f"Invalid name supplied for tool: '{name}'. Use only letters (a-z, A-Z), digits (0-9) and underscores (_)."
)
self.name = name
self.pipeline_or_node = pipeline_or_node
self.description = description
@ -165,7 +169,12 @@ class Agent:
return self._format_answer(query=query, transcript=transcript, answer=final_answer)
tool_name, tool_input = self._extract_tool_name_and_tool_input(pred=preds[0])
if tool_name is None or tool_input is None:
raise AgentError(f"Wrong output format. Transcript:\n{transcript}")
raise AgentError(
f"Could not identify the next tool or input for that tool from Agent's output. Adjust the Agent's param 'tool_pattern' or 'prompt_template'. \n"
f"# Agent's output: {preds[0]} \n"
f"# 'tool_pattern' to identify next tool: {self.tool_pattern} \n"
f"# Transcript:\n{transcript}"
)
result = self._run_tool(tool_name=tool_name, tool_input=tool_input, transcript=transcript, params=params)
observation = self._extract_observation(result)

View File

@ -704,7 +704,7 @@ def get_predefined_prompt_templates() -> List[PromptTemplate]:
"Use the following format:\n\n"
"Question: the question to be answered\n"
"Thought: Reason if you have the final answer. If yes, answer the question. If not, find out the missing information needed to answer it.\n"
"Tool: [$tool_names]\n"
"Tool: pick one of $tool_names \n"
"Tool Input: the input for the tool\n"
"Observation: the tool will respond with the result\n"
"...\n"

View File

@ -64,7 +64,9 @@ def test_agent_chooses_no_action():
description="useful for when you need to retrieve documents from your index",
)
)
with pytest.raises(AgentError, match=r"Wrong output format.*"):
with pytest.raises(
AgentError, match=r"Could not identify the next tool or input for that tool from Agent's output.*"
):
agent.run("How many letters does the name of the town where Christelle lives have?")