From 2a1d73e16d1036010d51d9906154230d562def48 Mon Sep 17 00:00:00 2001 From: Malte Pietsch Date: Tue, 28 Feb 2023 20:01:34 +0100 Subject: [PATCH] refactor: Make extraction of "Tool" and "Tool input" for Agent more robust and user-friendly (#4269) * adjust [] in prompt template. Add error+docs for Tool name. * fix test * update error message --- haystack/agents/base.py | 13 +++++++++++-- haystack/nodes/prompt/prompt_node.py | 2 +- test/agents/test_agent.py | 4 +++- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/haystack/agents/base.py b/haystack/agents/base.py index aaef4accb..3f8e6d35a 100644 --- a/haystack/agents/base.py +++ b/haystack/agents/base.py @@ -30,7 +30,7 @@ class Tool: :param name: The name of the tool. The Agent uses this name to refer to the tool in the text the Agent generates. The name should be short, ideally one token, and a good description of what the tool can do, for example - "Calculator" or "Search". + "Calculator" or "Search". Use only letters (a-z, A-Z), digits (0-9) and underscores (_).". :param pipeline_or_node: The pipeline or node to run when this tool is invoked by an Agent. :param description: A description of what the tool is useful for. An Agent can use this description for the decision when to use which tool. For example, a tool for calculations can be described by "useful for when you need to @@ -53,6 +53,10 @@ class Tool: ], description: str, ): + if re.search(r"\W", name): + raise ValueError( + f"Invalid name supplied for tool: '{name}'. Use only letters (a-z, A-Z), digits (0-9) and underscores (_)." + ) self.name = name self.pipeline_or_node = pipeline_or_node self.description = description @@ -165,7 +169,12 @@ class Agent: return self._format_answer(query=query, transcript=transcript, answer=final_answer) tool_name, tool_input = self._extract_tool_name_and_tool_input(pred=preds[0]) if tool_name is None or tool_input is None: - raise AgentError(f"Wrong output format. Transcript:\n{transcript}") + raise AgentError( + f"Could not identify the next tool or input for that tool from Agent's output. Adjust the Agent's param 'tool_pattern' or 'prompt_template'. \n" + f"# Agent's output: {preds[0]} \n" + f"# 'tool_pattern' to identify next tool: {self.tool_pattern} \n" + f"# Transcript:\n{transcript}" + ) result = self._run_tool(tool_name=tool_name, tool_input=tool_input, transcript=transcript, params=params) observation = self._extract_observation(result) diff --git a/haystack/nodes/prompt/prompt_node.py b/haystack/nodes/prompt/prompt_node.py index 8c86ba4f2..48e935b94 100644 --- a/haystack/nodes/prompt/prompt_node.py +++ b/haystack/nodes/prompt/prompt_node.py @@ -704,7 +704,7 @@ def get_predefined_prompt_templates() -> List[PromptTemplate]: "Use the following format:\n\n" "Question: the question to be answered\n" "Thought: Reason if you have the final answer. If yes, answer the question. If not, find out the missing information needed to answer it.\n" - "Tool: [$tool_names]\n" + "Tool: pick one of $tool_names \n" "Tool Input: the input for the tool\n" "Observation: the tool will respond with the result\n" "...\n" diff --git a/test/agents/test_agent.py b/test/agents/test_agent.py index 1191a547f..edfbf1b0d 100644 --- a/test/agents/test_agent.py +++ b/test/agents/test_agent.py @@ -64,7 +64,9 @@ def test_agent_chooses_no_action(): description="useful for when you need to retrieve documents from your index", ) ) - with pytest.raises(AgentError, match=r"Wrong output format.*"): + with pytest.raises( + AgentError, match=r"Could not identify the next tool or input for that tool from Agent's output.*" + ): agent.run("How many letters does the name of the town where Christelle lives have?")