mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-11-02 10:49:30 +00:00
feat: Consider prompt_node's default_prompt_template in agent (#5095)
* consider prompt_node's default_prompt_template in agent * make test a unit test via mocking * updated docstring
This commit is contained in:
parent
52e7a77595
commit
d8a4f20379
@ -235,10 +235,9 @@ class Agent:
|
||||
|
||||
:param prompt_node: The PromptNode that the Agent uses to decide which tool to use and what input to provide to
|
||||
it in each iteration.
|
||||
:param prompt_template: The name of a PromptTemplate for the PromptNode. It's used for generating thoughts and
|
||||
choosing tools to answer queries step-by-step. You can use the default `zero-shot-react` template or create a
|
||||
new template in a similar format.
|
||||
with `add_tool()` before running the Agent.
|
||||
:param prompt_template: A new PromptTemplate or the name of an existing PromptTemplate for the PromptNode. It's
|
||||
used for generating thoughts and choosing tools to answer queries step-by-step. If it's not set, the PromptNode's
|
||||
default template is used and if it's not set either, the Agent's default `zero-shot-react` template is used.
|
||||
:param tools_manager: A ToolsManager instance that the Agent uses to run tools. Each tool must have a unique name.
|
||||
You can also add tools with `add_tool()` before running the Agent.
|
||||
:param memory: A Memory instance that the Agent uses to store information between iterations.
|
||||
@ -258,7 +257,7 @@ class Agent:
|
||||
self.memory = memory or NoMemory()
|
||||
self.callback_manager = Events(("on_agent_start", "on_agent_step", "on_agent_finish", "on_new_token"))
|
||||
self.prompt_node = prompt_node
|
||||
prompt_template = prompt_template or "zero-shot-react"
|
||||
prompt_template = prompt_template or prompt_node.default_prompt_template or "zero-shot-react"
|
||||
resolved_prompt_template = prompt_node.get_prompt_template(prompt_template)
|
||||
if not resolved_prompt_template:
|
||||
raise ValueError(
|
||||
|
||||
@ -2,6 +2,7 @@ import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Tuple
|
||||
from unittest.mock import patch
|
||||
|
||||
from test.conftest import MockRetriever, MockPromptNode
|
||||
from unittest import mock
|
||||
@ -298,3 +299,19 @@ def test_invalid_agent_template():
|
||||
a = Agent(prompt_node=pn, prompt_template=None)
|
||||
assert isinstance(a.prompt_template, PromptTemplate)
|
||||
assert a.prompt_template.name == "zero-shot-react"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
@patch.object(PromptNode, "prompt")
|
||||
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
|
||||
def test_default_template_order(mock_model, mock_prompt):
|
||||
pn = PromptNode("abc")
|
||||
a = Agent(prompt_node=pn)
|
||||
assert a.prompt_template.name == "zero-shot-react"
|
||||
|
||||
pn.default_prompt_template = "language-detection"
|
||||
a = Agent(prompt_node=pn)
|
||||
assert a.prompt_template.name == "language-detection"
|
||||
|
||||
a = Agent(prompt_node=pn, prompt_template="translation")
|
||||
assert a.prompt_template.name == "translation"
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user