feat: Consider prompt_node's default_prompt_template in agent (#5095)

* consider prompt_node's default_prompt_template in agent

* make test a unit test via mocking

* updated docstring
This commit is contained in:
Julian Risch 2023-06-08 13:42:28 +02:00 committed by GitHub
parent 52e7a77595
commit d8a4f20379
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 5 deletions

View File

@ -235,10 +235,9 @@ class Agent:
:param prompt_node: The PromptNode that the Agent uses to decide which tool to use and what input to provide to
it in each iteration.
:param prompt_template: The name of a PromptTemplate for the PromptNode. It's used for generating thoughts and
choosing tools to answer queries step-by-step. You can use the default `zero-shot-react` template or create a
new template in a similar format.
with `add_tool()` before running the Agent.
:param prompt_template: A new PromptTemplate or the name of an existing PromptTemplate for the PromptNode. It's
used for generating thoughts and choosing tools to answer queries step-by-step. If it's not set, the PromptNode's
default template is used and if it's not set either, the Agent's default `zero-shot-react` template is used.
:param tools_manager: A ToolsManager instance that the Agent uses to run tools. Each tool must have a unique name.
You can also add tools with `add_tool()` before running the Agent.
:param memory: A Memory instance that the Agent uses to store information between iterations.
@ -258,7 +257,7 @@ class Agent:
self.memory = memory or NoMemory()
self.callback_manager = Events(("on_agent_start", "on_agent_step", "on_agent_finish", "on_new_token"))
self.prompt_node = prompt_node
prompt_template = prompt_template or "zero-shot-react"
prompt_template = prompt_template or prompt_node.default_prompt_template or "zero-shot-react"
resolved_prompt_template = prompt_node.get_prompt_template(prompt_template)
if not resolved_prompt_template:
raise ValueError(

View File

@ -2,6 +2,7 @@ import logging
import os
import re
from typing import Tuple
from unittest.mock import patch
from test.conftest import MockRetriever, MockPromptNode
from unittest import mock
@ -298,3 +299,19 @@ def test_invalid_agent_template():
a = Agent(prompt_node=pn, prompt_template=None)
assert isinstance(a.prompt_template, PromptTemplate)
assert a.prompt_template.name == "zero-shot-react"
@pytest.mark.unit
@patch.object(PromptNode, "prompt")
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_default_template_order(mock_model, mock_prompt):
pn = PromptNode("abc")
a = Agent(prompt_node=pn)
assert a.prompt_template.name == "zero-shot-react"
pn.default_prompt_template = "language-detection"
a = Agent(prompt_node=pn)
assert a.prompt_template.name == "language-detection"
a = Agent(prompt_node=pn, prompt_template="translation")
assert a.prompt_template.name == "translation"