feat: Consider prompt_node's default_prompt_template in agent (#5095)

* consider prompt_node's default_prompt_template in agent

* make test a unit test via mocking

* updated docstring
This commit is contained in:
Julian Risch 2023-06-08 13:42:28 +02:00 committed by GitHub
parent 52e7a77595
commit d8a4f20379
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 21 additions and 5 deletions

View File

@ -235,10 +235,9 @@ class Agent:
:param prompt_node: The PromptNode that the Agent uses to decide which tool to use and what input to provide to :param prompt_node: The PromptNode that the Agent uses to decide which tool to use and what input to provide to
it in each iteration. it in each iteration.
:param prompt_template: The name of a PromptTemplate for the PromptNode. It's used for generating thoughts and :param prompt_template: A new PromptTemplate or the name of an existing PromptTemplate for the PromptNode. It's
choosing tools to answer queries step-by-step. You can use the default `zero-shot-react` template or create a used for generating thoughts and choosing tools to answer queries step-by-step. If it's not set, the PromptNode's
new template in a similar format. default template is used and if it's not set either, the Agent's default `zero-shot-react` template is used.
with `add_tool()` before running the Agent.
:param tools_manager: A ToolsManager instance that the Agent uses to run tools. Each tool must have a unique name. :param tools_manager: A ToolsManager instance that the Agent uses to run tools. Each tool must have a unique name.
You can also add tools with `add_tool()` before running the Agent. You can also add tools with `add_tool()` before running the Agent.
:param memory: A Memory instance that the Agent uses to store information between iterations. :param memory: A Memory instance that the Agent uses to store information between iterations.
@ -258,7 +257,7 @@ class Agent:
self.memory = memory or NoMemory() self.memory = memory or NoMemory()
self.callback_manager = Events(("on_agent_start", "on_agent_step", "on_agent_finish", "on_new_token")) self.callback_manager = Events(("on_agent_start", "on_agent_step", "on_agent_finish", "on_new_token"))
self.prompt_node = prompt_node self.prompt_node = prompt_node
prompt_template = prompt_template or "zero-shot-react" prompt_template = prompt_template or prompt_node.default_prompt_template or "zero-shot-react"
resolved_prompt_template = prompt_node.get_prompt_template(prompt_template) resolved_prompt_template = prompt_node.get_prompt_template(prompt_template)
if not resolved_prompt_template: if not resolved_prompt_template:
raise ValueError( raise ValueError(

View File

@ -2,6 +2,7 @@ import logging
import os import os
import re import re
from typing import Tuple from typing import Tuple
from unittest.mock import patch
from test.conftest import MockRetriever, MockPromptNode from test.conftest import MockRetriever, MockPromptNode
from unittest import mock from unittest import mock
@ -298,3 +299,19 @@ def test_invalid_agent_template():
a = Agent(prompt_node=pn, prompt_template=None) a = Agent(prompt_node=pn, prompt_template=None)
assert isinstance(a.prompt_template, PromptTemplate) assert isinstance(a.prompt_template, PromptTemplate)
assert a.prompt_template.name == "zero-shot-react" assert a.prompt_template.name == "zero-shot-react"
@pytest.mark.unit
@patch.object(PromptNode, "prompt")
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_default_template_order(mock_model, mock_prompt):
pn = PromptNode("abc")
a = Agent(prompt_node=pn)
assert a.prompt_template.name == "zero-shot-react"
pn.default_prompt_template = "language-detection"
a = Agent(prompt_node=pn)
assert a.prompt_template.name == "language-detection"
a = Agent(prompt_node=pn, prompt_template="translation")
assert a.prompt_template.name == "translation"