mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-14 18:52:11 +00:00
fix: Check Agent's prompt template variables and prompt resolver parameters are aligned (#5163)
* Check Agent's prompt template parameters and prompt resolver parameters are aligned * Lower the logger warning * Automatically append transcript if needed * Amend flaky test
This commit is contained in:
parent
6a1b6b1ae3
commit
089187ac8b
@ -396,16 +396,14 @@ class Agent:
|
||||
# first resolve prompt template params
|
||||
template_params = self.prompt_parameters_resolver(query=query, agent=self, agent_step=current_step)
|
||||
|
||||
# if prompt node has no default prompt template, use agent's prompt template
|
||||
if self.prompt_node.default_prompt_template is None:
|
||||
prepared_prompt = next(self.prompt_template.fill(**template_params))
|
||||
prompt_node_response = self.prompt_node(
|
||||
prepared_prompt, stream_handler=AgentTokenStreamingHandler(self.callback_manager)
|
||||
)
|
||||
# otherwise, if prompt node has default prompt template, use it
|
||||
else:
|
||||
prompt_node_response = self.prompt_node(
|
||||
stream_handler=AgentTokenStreamingHandler(self.callback_manager), **template_params
|
||||
# check for template parameters mismatch
|
||||
self.check_prompt_template(template_params)
|
||||
|
||||
# invoke via prompt node
|
||||
prompt_node_response = self.prompt_node.prompt(
|
||||
prompt_template=self.prompt_template,
|
||||
stream_handler=AgentTokenStreamingHandler(self.callback_manager),
|
||||
**template_params,
|
||||
)
|
||||
return prompt_node_response
|
||||
|
||||
@ -422,3 +420,35 @@ class Agent:
|
||||
return {
|
||||
k: v if isinstance(v, str) else next(iter(v)) for k, v in kwargs.items() if isinstance(v, (str, Iterable))
|
||||
}
|
||||
|
||||
def check_prompt_template(self, template_params: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Verifies that the Agent's prompt template is adequately populated with the correct parameters
|
||||
provided by the prompt parameter resolver.
|
||||
|
||||
If template_params contains a parameter that is not specified in the prompt template, a warning is logged
|
||||
at DEBUG level. Sometimes the prompt parameter resolver may provide additional parameters that are not
|
||||
used by the prompt template. However, if the prompt parameter resolver provides a 'transcript'
|
||||
parameter that is not used in the prompt template, an error is logged.
|
||||
|
||||
:param template_params: The parameters provided by the prompt parameter resolver.
|
||||
|
||||
"""
|
||||
unused_params = set(template_params.keys()) - set(self.prompt_template.prompt_params)
|
||||
|
||||
if "transcript" in unused_params:
|
||||
logger.warning(
|
||||
"The 'transcript' parameter is missing from the Agent's prompt template. All ReAct agents "
|
||||
"that go through multiple steps to reach a goal require this parameter. Please append {transcript} "
|
||||
"to the end of the Agent's prompt template to ensure its proper functioning. A temporary prompt "
|
||||
"template with {transcript} appended will be used for this run."
|
||||
)
|
||||
new_prompt_text = self.prompt_template.prompt_text + "\n {transcript}"
|
||||
self.prompt_template = PromptTemplate(prompt=new_prompt_text)
|
||||
|
||||
elif unused_params:
|
||||
logger.debug(
|
||||
"The Agent's prompt template does not utilize the following parameters provided by the "
|
||||
"prompt parameter resolver: %s. Note that these parameters are available for use if needed.",
|
||||
list(unused_params),
|
||||
)
|
||||
|
@ -2,18 +2,17 @@ import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Tuple
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from events import Events
|
||||
|
||||
from haystack.agents.types import AgentTokenStreamingHandler, AgentToolLogger
|
||||
from test.conftest import MockRetriever, MockPromptNode
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock, patch
|
||||
from test.conftest import MockRetriever, MockPromptNode
|
||||
|
||||
import pytest
|
||||
from events import Events
|
||||
|
||||
from haystack import BaseComponent, Answer, Document
|
||||
from haystack.agents import Agent, AgentStep
|
||||
from haystack.agents.base import Tool, ToolsManager
|
||||
from haystack.agents.types import AgentTokenStreamingHandler, AgentToolLogger
|
||||
from haystack.nodes import PromptModel, PromptNode, PromptTemplate
|
||||
from haystack.pipelines import ExtractiveQAPipeline, DocumentSearchPipeline, BaseStandardPipeline
|
||||
|
||||
@ -356,3 +355,77 @@ def test_agent_token_streaming_handler():
|
||||
|
||||
assert result == "test"
|
||||
mock_callback.assert_called_once_with("test") # assert that the mock callback was called with "test"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_prompt_template_parameter_has_transcript(caplog):
|
||||
mock_prompt_node = Mock(spec=PromptNode)
|
||||
prompt = PromptTemplate(prompt="I now have {query} as a template parameter but also {transcript}")
|
||||
mock_prompt_node.get_prompt_template.return_value = prompt
|
||||
|
||||
agent = Agent(prompt_node=mock_prompt_node)
|
||||
agent.check_prompt_template({"query": "test", "transcript": "some fake transcript"})
|
||||
assert "The 'transcript' parameter is missing from the Agent's prompt template" not in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_prompt_template_has_no_transcript(caplog):
|
||||
mock_prompt_node = Mock(spec=PromptNode)
|
||||
prompt = PromptTemplate(prompt="I only have {query} as a template parameter but I am missing transcript variable")
|
||||
mock_prompt_node.get_prompt_template.return_value = prompt
|
||||
agent = Agent(prompt_node=mock_prompt_node)
|
||||
|
||||
# We start with no transcript in the prompt template
|
||||
assert "transcript" not in prompt.prompt_params
|
||||
assert "transcript" not in agent.prompt_template.prompt_params
|
||||
|
||||
agent.check_prompt_template({"query": "test", "transcript": "some fake transcript"})
|
||||
assert "The 'transcript' parameter is missing from the Agent's prompt template" in caplog.text
|
||||
|
||||
# now let's check again after adding the transcript
|
||||
# query was there to begin with
|
||||
assert "query" in agent.prompt_template.prompt_params
|
||||
# transcript was added automatically for this prompt template and run
|
||||
assert "transcript" in agent.prompt_template.prompt_params
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_prompt_template_unused_parameters(caplog):
|
||||
caplog.set_level(logging.DEBUG)
|
||||
mock_prompt_node = Mock(spec=PromptNode)
|
||||
prompt = PromptTemplate(prompt="I now have {query} and {transcript} as template parameters")
|
||||
mock_prompt_node.get_prompt_template.return_value = prompt
|
||||
agent = Agent(prompt_node=mock_prompt_node)
|
||||
agent.check_prompt_template({"query": "test", "transcript": "some fake transcript", "unused": "test"})
|
||||
assert (
|
||||
"The Agent's prompt template does not utilize the following parameters provided by the "
|
||||
"prompt parameter resolver: ['unused']" in caplog.text
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_prompt_template_multiple_unused_parameters(caplog):
|
||||
caplog.set_level(logging.DEBUG)
|
||||
mock_prompt_node = Mock(spec=PromptNode)
|
||||
prompt = PromptTemplate(prompt="I now have strange {param_1} and {param_2} as template parameters")
|
||||
mock_prompt_node.get_prompt_template.return_value = prompt
|
||||
agent = Agent(prompt_node=mock_prompt_node)
|
||||
agent.check_prompt_template({"query": "test", "unused": "test"})
|
||||
# order of parameters in the list not guaranteed, so we check for preamble of the message
|
||||
assert (
|
||||
"The Agent's prompt template does not utilize the following parameters provided by the "
|
||||
"prompt parameter resolver" in caplog.text
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_agent_prompt_template_missing_parameters(caplog):
|
||||
# in check_prompt_template we don't check that all prompt template parameters are filled
|
||||
# prompt template resolution will do that and flag the missing parameters
|
||||
# in check_prompt_template we check if some template parameters are not used
|
||||
mock_prompt_node = Mock(spec=PromptNode)
|
||||
prompt = PromptTemplate(prompt="I now have {query} and {transcript} as template parameters")
|
||||
mock_prompt_node.get_prompt_template.return_value = prompt
|
||||
agent = Agent(prompt_node=mock_prompt_node)
|
||||
agent.check_prompt_template({"transcript": "test"})
|
||||
assert not caplog.text
|
||||
|
Loading…
x
Reference in New Issue
Block a user