diff --git a/haystack/agents/base.py b/haystack/agents/base.py index 960e610a2..ebf54d45d 100644 --- a/haystack/agents/base.py +++ b/haystack/agents/base.py @@ -235,10 +235,9 @@ class Agent: :param prompt_node: The PromptNode that the Agent uses to decide which tool to use and what input to provide to it in each iteration. - :param prompt_template: The name of a PromptTemplate for the PromptNode. It's used for generating thoughts and - choosing tools to answer queries step-by-step. You can use the default `zero-shot-react` template or create a - new template in a similar format. - with `add_tool()` before running the Agent. + :param prompt_template: A new PromptTemplate or the name of an existing PromptTemplate for the PromptNode. It's + used for generating thoughts and choosing tools to answer queries step-by-step. If it's not set, the PromptNode's + default template is used and if it's not set either, the Agent's default `zero-shot-react` template is used. :param tools_manager: A ToolsManager instance that the Agent uses to run tools. Each tool must have a unique name. You can also add tools with `add_tool()` before running the Agent. :param memory: A Memory instance that the Agent uses to store information between iterations. @@ -258,7 +257,7 @@ class Agent: self.memory = memory or NoMemory() self.callback_manager = Events(("on_agent_start", "on_agent_step", "on_agent_finish", "on_new_token")) self.prompt_node = prompt_node - prompt_template = prompt_template or "zero-shot-react" + prompt_template = prompt_template or prompt_node.default_prompt_template or "zero-shot-react" resolved_prompt_template = prompt_node.get_prompt_template(prompt_template) if not resolved_prompt_template: raise ValueError( diff --git a/test/agents/test_agent.py b/test/agents/test_agent.py index a88da65c5..346e47ca0 100644 --- a/test/agents/test_agent.py +++ b/test/agents/test_agent.py @@ -2,6 +2,7 @@ import logging import os import re from typing import Tuple +from unittest.mock import patch from test.conftest import MockRetriever, MockPromptNode from unittest import mock @@ -298,3 +299,19 @@ def test_invalid_agent_template(): a = Agent(prompt_node=pn, prompt_template=None) assert isinstance(a.prompt_template, PromptTemplate) assert a.prompt_template.name == "zero-shot-react" + + +@pytest.mark.unit +@patch.object(PromptNode, "prompt") +@patch("haystack.nodes.prompt.prompt_node.PromptModel") +def test_default_template_order(mock_model, mock_prompt): + pn = PromptNode("abc") + a = Agent(prompt_node=pn) + assert a.prompt_template.name == "zero-shot-react" + + pn.default_prompt_template = "language-detection" + a = Agent(prompt_node=pn) + assert a.prompt_template.name == "language-detection" + + a = Agent(prompt_node=pn, prompt_template="translation") + assert a.prompt_template.name == "translation"