mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-15 11:11:55 +00:00
fix: Check Agent's prompt template variables and prompt resolver parameters are aligned (#5163)
* Check Agent's prompt template parameters and prompt resolver parameters are aligned * Lower the logger warning * Automatically append transcript if needed * Amend flaky test
This commit is contained in:
parent
6a1b6b1ae3
commit
089187ac8b
@ -396,16 +396,14 @@ class Agent:
|
|||||||
# first resolve prompt template params
|
# first resolve prompt template params
|
||||||
template_params = self.prompt_parameters_resolver(query=query, agent=self, agent_step=current_step)
|
template_params = self.prompt_parameters_resolver(query=query, agent=self, agent_step=current_step)
|
||||||
|
|
||||||
# if prompt node has no default prompt template, use agent's prompt template
|
# check for template parameters mismatch
|
||||||
if self.prompt_node.default_prompt_template is None:
|
self.check_prompt_template(template_params)
|
||||||
prepared_prompt = next(self.prompt_template.fill(**template_params))
|
|
||||||
prompt_node_response = self.prompt_node(
|
# invoke via prompt node
|
||||||
prepared_prompt, stream_handler=AgentTokenStreamingHandler(self.callback_manager)
|
prompt_node_response = self.prompt_node.prompt(
|
||||||
)
|
prompt_template=self.prompt_template,
|
||||||
# otherwise, if prompt node has default prompt template, use it
|
stream_handler=AgentTokenStreamingHandler(self.callback_manager),
|
||||||
else:
|
**template_params,
|
||||||
prompt_node_response = self.prompt_node(
|
|
||||||
stream_handler=AgentTokenStreamingHandler(self.callback_manager), **template_params
|
|
||||||
)
|
)
|
||||||
return prompt_node_response
|
return prompt_node_response
|
||||||
|
|
||||||
@ -422,3 +420,35 @@ class Agent:
|
|||||||
return {
|
return {
|
||||||
k: v if isinstance(v, str) else next(iter(v)) for k, v in kwargs.items() if isinstance(v, (str, Iterable))
|
k: v if isinstance(v, str) else next(iter(v)) for k, v in kwargs.items() if isinstance(v, (str, Iterable))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def check_prompt_template(self, template_params: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Verifies that the Agent's prompt template is adequately populated with the correct parameters
|
||||||
|
provided by the prompt parameter resolver.
|
||||||
|
|
||||||
|
If template_params contains a parameter that is not specified in the prompt template, a warning is logged
|
||||||
|
at DEBUG level. Sometimes the prompt parameter resolver may provide additional parameters that are not
|
||||||
|
used by the prompt template. However, if the prompt parameter resolver provides a 'transcript'
|
||||||
|
parameter that is not used in the prompt template, an error is logged.
|
||||||
|
|
||||||
|
:param template_params: The parameters provided by the prompt parameter resolver.
|
||||||
|
|
||||||
|
"""
|
||||||
|
unused_params = set(template_params.keys()) - set(self.prompt_template.prompt_params)
|
||||||
|
|
||||||
|
if "transcript" in unused_params:
|
||||||
|
logger.warning(
|
||||||
|
"The 'transcript' parameter is missing from the Agent's prompt template. All ReAct agents "
|
||||||
|
"that go through multiple steps to reach a goal require this parameter. Please append {transcript} "
|
||||||
|
"to the end of the Agent's prompt template to ensure its proper functioning. A temporary prompt "
|
||||||
|
"template with {transcript} appended will be used for this run."
|
||||||
|
)
|
||||||
|
new_prompt_text = self.prompt_template.prompt_text + "\n {transcript}"
|
||||||
|
self.prompt_template = PromptTemplate(prompt=new_prompt_text)
|
||||||
|
|
||||||
|
elif unused_params:
|
||||||
|
logger.debug(
|
||||||
|
"The Agent's prompt template does not utilize the following parameters provided by the "
|
||||||
|
"prompt parameter resolver: %s. Note that these parameters are available for use if needed.",
|
||||||
|
list(unused_params),
|
||||||
|
)
|
||||||
|
@ -2,18 +2,17 @@ import logging
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
|
|
||||||
from events import Events
|
|
||||||
|
|
||||||
from haystack.agents.types import AgentTokenStreamingHandler, AgentToolLogger
|
|
||||||
from test.conftest import MockRetriever, MockPromptNode
|
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
from test.conftest import MockRetriever, MockPromptNode
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from events import Events
|
||||||
|
|
||||||
from haystack import BaseComponent, Answer, Document
|
from haystack import BaseComponent, Answer, Document
|
||||||
from haystack.agents import Agent, AgentStep
|
from haystack.agents import Agent, AgentStep
|
||||||
from haystack.agents.base import Tool, ToolsManager
|
from haystack.agents.base import Tool, ToolsManager
|
||||||
|
from haystack.agents.types import AgentTokenStreamingHandler, AgentToolLogger
|
||||||
from haystack.nodes import PromptModel, PromptNode, PromptTemplate
|
from haystack.nodes import PromptModel, PromptNode, PromptTemplate
|
||||||
from haystack.pipelines import ExtractiveQAPipeline, DocumentSearchPipeline, BaseStandardPipeline
|
from haystack.pipelines import ExtractiveQAPipeline, DocumentSearchPipeline, BaseStandardPipeline
|
||||||
|
|
||||||
@ -356,3 +355,77 @@ def test_agent_token_streaming_handler():
|
|||||||
|
|
||||||
assert result == "test"
|
assert result == "test"
|
||||||
mock_callback.assert_called_once_with("test") # assert that the mock callback was called with "test"
|
mock_callback.assert_called_once_with("test") # assert that the mock callback was called with "test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_agent_prompt_template_parameter_has_transcript(caplog):
|
||||||
|
mock_prompt_node = Mock(spec=PromptNode)
|
||||||
|
prompt = PromptTemplate(prompt="I now have {query} as a template parameter but also {transcript}")
|
||||||
|
mock_prompt_node.get_prompt_template.return_value = prompt
|
||||||
|
|
||||||
|
agent = Agent(prompt_node=mock_prompt_node)
|
||||||
|
agent.check_prompt_template({"query": "test", "transcript": "some fake transcript"})
|
||||||
|
assert "The 'transcript' parameter is missing from the Agent's prompt template" not in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_agent_prompt_template_has_no_transcript(caplog):
|
||||||
|
mock_prompt_node = Mock(spec=PromptNode)
|
||||||
|
prompt = PromptTemplate(prompt="I only have {query} as a template parameter but I am missing transcript variable")
|
||||||
|
mock_prompt_node.get_prompt_template.return_value = prompt
|
||||||
|
agent = Agent(prompt_node=mock_prompt_node)
|
||||||
|
|
||||||
|
# We start with no transcript in the prompt template
|
||||||
|
assert "transcript" not in prompt.prompt_params
|
||||||
|
assert "transcript" not in agent.prompt_template.prompt_params
|
||||||
|
|
||||||
|
agent.check_prompt_template({"query": "test", "transcript": "some fake transcript"})
|
||||||
|
assert "The 'transcript' parameter is missing from the Agent's prompt template" in caplog.text
|
||||||
|
|
||||||
|
# now let's check again after adding the transcript
|
||||||
|
# query was there to begin with
|
||||||
|
assert "query" in agent.prompt_template.prompt_params
|
||||||
|
# transcript was added automatically for this prompt template and run
|
||||||
|
assert "transcript" in agent.prompt_template.prompt_params
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_agent_prompt_template_unused_parameters(caplog):
|
||||||
|
caplog.set_level(logging.DEBUG)
|
||||||
|
mock_prompt_node = Mock(spec=PromptNode)
|
||||||
|
prompt = PromptTemplate(prompt="I now have {query} and {transcript} as template parameters")
|
||||||
|
mock_prompt_node.get_prompt_template.return_value = prompt
|
||||||
|
agent = Agent(prompt_node=mock_prompt_node)
|
||||||
|
agent.check_prompt_template({"query": "test", "transcript": "some fake transcript", "unused": "test"})
|
||||||
|
assert (
|
||||||
|
"The Agent's prompt template does not utilize the following parameters provided by the "
|
||||||
|
"prompt parameter resolver: ['unused']" in caplog.text
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_agent_prompt_template_multiple_unused_parameters(caplog):
|
||||||
|
caplog.set_level(logging.DEBUG)
|
||||||
|
mock_prompt_node = Mock(spec=PromptNode)
|
||||||
|
prompt = PromptTemplate(prompt="I now have strange {param_1} and {param_2} as template parameters")
|
||||||
|
mock_prompt_node.get_prompt_template.return_value = prompt
|
||||||
|
agent = Agent(prompt_node=mock_prompt_node)
|
||||||
|
agent.check_prompt_template({"query": "test", "unused": "test"})
|
||||||
|
# order of parameters in the list not guaranteed, so we check for preamble of the message
|
||||||
|
assert (
|
||||||
|
"The Agent's prompt template does not utilize the following parameters provided by the "
|
||||||
|
"prompt parameter resolver" in caplog.text
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.unit
|
||||||
|
def test_agent_prompt_template_missing_parameters(caplog):
|
||||||
|
# in check_prompt_template we don't check that all prompt template parameters are filled
|
||||||
|
# prompt template resolution will do that and flag the missing parameters
|
||||||
|
# in check_prompt_template we check if some template parameters are not used
|
||||||
|
mock_prompt_node = Mock(spec=PromptNode)
|
||||||
|
prompt = PromptTemplate(prompt="I now have {query} and {transcript} as template parameters")
|
||||||
|
mock_prompt_node.get_prompt_template.return_value = prompt
|
||||||
|
agent = Agent(prompt_node=mock_prompt_node)
|
||||||
|
agent.check_prompt_template({"transcript": "test"})
|
||||||
|
assert not caplog.text
|
||||||
|
Loading…
x
Reference in New Issue
Block a user