fix: Check Agent's prompt template variables and prompt resolver parameters are aligned (#5163)

* Check Agent's prompt template parameters and prompt resolver parameters are aligned

* Lower the logger warning

* Automatically append transcript if needed

* Amend flaky test
This commit is contained in:
Vladimir Blagojevic 2023-06-21 14:34:41 +02:00 committed by GitHub
parent 6a1b6b1ae3
commit 089187ac8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 120 additions and 17 deletions

View File

@ -396,17 +396,15 @@ class Agent:
# first resolve prompt template params # first resolve prompt template params
template_params = self.prompt_parameters_resolver(query=query, agent=self, agent_step=current_step) template_params = self.prompt_parameters_resolver(query=query, agent=self, agent_step=current_step)
# if prompt node has no default prompt template, use agent's prompt template # check for template parameters mismatch
if self.prompt_node.default_prompt_template is None: self.check_prompt_template(template_params)
prepared_prompt = next(self.prompt_template.fill(**template_params))
prompt_node_response = self.prompt_node( # invoke via prompt node
prepared_prompt, stream_handler=AgentTokenStreamingHandler(self.callback_manager) prompt_node_response = self.prompt_node.prompt(
) prompt_template=self.prompt_template,
# otherwise, if prompt node has default prompt template, use it stream_handler=AgentTokenStreamingHandler(self.callback_manager),
else: **template_params,
prompt_node_response = self.prompt_node( )
stream_handler=AgentTokenStreamingHandler(self.callback_manager), **template_params
)
return prompt_node_response return prompt_node_response
def create_agent_step(self, max_steps: Optional[int] = None) -> AgentStep: def create_agent_step(self, max_steps: Optional[int] = None) -> AgentStep:
@ -422,3 +420,35 @@ class Agent:
return { return {
k: v if isinstance(v, str) else next(iter(v)) for k, v in kwargs.items() if isinstance(v, (str, Iterable)) k: v if isinstance(v, str) else next(iter(v)) for k, v in kwargs.items() if isinstance(v, (str, Iterable))
} }
def check_prompt_template(self, template_params: Dict[str, Any]) -> None:
"""
Verifies that the Agent's prompt template is adequately populated with the correct parameters
provided by the prompt parameter resolver.
If template_params contains a parameter that is not specified in the prompt template, a warning is logged
at DEBUG level. Sometimes the prompt parameter resolver may provide additional parameters that are not
used by the prompt template. However, if the prompt parameter resolver provides a 'transcript'
parameter that is not used in the prompt template, an error is logged.
:param template_params: The parameters provided by the prompt parameter resolver.
"""
unused_params = set(template_params.keys()) - set(self.prompt_template.prompt_params)
if "transcript" in unused_params:
logger.warning(
"The 'transcript' parameter is missing from the Agent's prompt template. All ReAct agents "
"that go through multiple steps to reach a goal require this parameter. Please append {transcript} "
"to the end of the Agent's prompt template to ensure its proper functioning. A temporary prompt "
"template with {transcript} appended will be used for this run."
)
new_prompt_text = self.prompt_template.prompt_text + "\n {transcript}"
self.prompt_template = PromptTemplate(prompt=new_prompt_text)
elif unused_params:
logger.debug(
"The Agent's prompt template does not utilize the following parameters provided by the "
"prompt parameter resolver: %s. Note that these parameters are available for use if needed.",
list(unused_params),
)

View File

@ -2,18 +2,17 @@ import logging
import os import os
import re import re
from typing import Tuple from typing import Tuple
from unittest.mock import Mock, patch
from events import Events
from haystack.agents.types import AgentTokenStreamingHandler, AgentToolLogger
from test.conftest import MockRetriever, MockPromptNode
from unittest import mock from unittest import mock
from unittest.mock import Mock, patch
from test.conftest import MockRetriever, MockPromptNode
import pytest import pytest
from events import Events
from haystack import BaseComponent, Answer, Document from haystack import BaseComponent, Answer, Document
from haystack.agents import Agent, AgentStep from haystack.agents import Agent, AgentStep
from haystack.agents.base import Tool, ToolsManager from haystack.agents.base import Tool, ToolsManager
from haystack.agents.types import AgentTokenStreamingHandler, AgentToolLogger
from haystack.nodes import PromptModel, PromptNode, PromptTemplate from haystack.nodes import PromptModel, PromptNode, PromptTemplate
from haystack.pipelines import ExtractiveQAPipeline, DocumentSearchPipeline, BaseStandardPipeline from haystack.pipelines import ExtractiveQAPipeline, DocumentSearchPipeline, BaseStandardPipeline
@ -356,3 +355,77 @@ def test_agent_token_streaming_handler():
assert result == "test" assert result == "test"
mock_callback.assert_called_once_with("test") # assert that the mock callback was called with "test" mock_callback.assert_called_once_with("test") # assert that the mock callback was called with "test"
@pytest.mark.unit
def test_agent_prompt_template_parameter_has_transcript(caplog):
mock_prompt_node = Mock(spec=PromptNode)
prompt = PromptTemplate(prompt="I now have {query} as a template parameter but also {transcript}")
mock_prompt_node.get_prompt_template.return_value = prompt
agent = Agent(prompt_node=mock_prompt_node)
agent.check_prompt_template({"query": "test", "transcript": "some fake transcript"})
assert "The 'transcript' parameter is missing from the Agent's prompt template" not in caplog.text
@pytest.mark.unit
def test_agent_prompt_template_has_no_transcript(caplog):
mock_prompt_node = Mock(spec=PromptNode)
prompt = PromptTemplate(prompt="I only have {query} as a template parameter but I am missing transcript variable")
mock_prompt_node.get_prompt_template.return_value = prompt
agent = Agent(prompt_node=mock_prompt_node)
# We start with no transcript in the prompt template
assert "transcript" not in prompt.prompt_params
assert "transcript" not in agent.prompt_template.prompt_params
agent.check_prompt_template({"query": "test", "transcript": "some fake transcript"})
assert "The 'transcript' parameter is missing from the Agent's prompt template" in caplog.text
# now let's check again after adding the transcript
# query was there to begin with
assert "query" in agent.prompt_template.prompt_params
# transcript was added automatically for this prompt template and run
assert "transcript" in agent.prompt_template.prompt_params
@pytest.mark.unit
def test_agent_prompt_template_unused_parameters(caplog):
caplog.set_level(logging.DEBUG)
mock_prompt_node = Mock(spec=PromptNode)
prompt = PromptTemplate(prompt="I now have {query} and {transcript} as template parameters")
mock_prompt_node.get_prompt_template.return_value = prompt
agent = Agent(prompt_node=mock_prompt_node)
agent.check_prompt_template({"query": "test", "transcript": "some fake transcript", "unused": "test"})
assert (
"The Agent's prompt template does not utilize the following parameters provided by the "
"prompt parameter resolver: ['unused']" in caplog.text
)
@pytest.mark.unit
def test_agent_prompt_template_multiple_unused_parameters(caplog):
caplog.set_level(logging.DEBUG)
mock_prompt_node = Mock(spec=PromptNode)
prompt = PromptTemplate(prompt="I now have strange {param_1} and {param_2} as template parameters")
mock_prompt_node.get_prompt_template.return_value = prompt
agent = Agent(prompt_node=mock_prompt_node)
agent.check_prompt_template({"query": "test", "unused": "test"})
# order of parameters in the list not guaranteed, so we check for preamble of the message
assert (
"The Agent's prompt template does not utilize the following parameters provided by the "
"prompt parameter resolver" in caplog.text
)
@pytest.mark.unit
def test_agent_prompt_template_missing_parameters(caplog):
# in check_prompt_template we don't check that all prompt template parameters are filled
# prompt template resolution will do that and flag the missing parameters
# in check_prompt_template we check if some template parameters are not used
mock_prompt_node = Mock(spec=PromptNode)
prompt = PromptTemplate(prompt="I now have {query} and {transcript} as template parameters")
mock_prompt_node.get_prompt_template.return_value = prompt
agent = Agent(prompt_node=mock_prompt_node)
agent.check_prompt_template({"transcript": "test"})
assert not caplog.text