From 089187ac8b47b24517580c23f70c9c36d1170e5f Mon Sep 17 00:00:00 2001 From: Vladimir Blagojevic Date: Wed, 21 Jun 2023 14:34:41 +0200 Subject: [PATCH] fix: Check Agent's prompt template variables and prompt resolver parameters are aligned (#5163) * Check Agent's prompt template parameters and prompt resolver parameters are aligned * Lower the logger warning * Automatically append transcript if needed * Amend flaky test --- haystack/agents/base.py | 52 +++++++++++++++++++----- test/agents/test_agent.py | 85 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 120 insertions(+), 17 deletions(-) diff --git a/haystack/agents/base.py b/haystack/agents/base.py index 4dcb40745..6e8ba453d 100644 --- a/haystack/agents/base.py +++ b/haystack/agents/base.py @@ -396,17 +396,15 @@ class Agent: # first resolve prompt template params template_params = self.prompt_parameters_resolver(query=query, agent=self, agent_step=current_step) - # if prompt node has no default prompt template, use agent's prompt template - if self.prompt_node.default_prompt_template is None: - prepared_prompt = next(self.prompt_template.fill(**template_params)) - prompt_node_response = self.prompt_node( - prepared_prompt, stream_handler=AgentTokenStreamingHandler(self.callback_manager) - ) - # otherwise, if prompt node has default prompt template, use it - else: - prompt_node_response = self.prompt_node( - stream_handler=AgentTokenStreamingHandler(self.callback_manager), **template_params - ) + # check for template parameters mismatch + self.check_prompt_template(template_params) + + # invoke via prompt node + prompt_node_response = self.prompt_node.prompt( + prompt_template=self.prompt_template, + stream_handler=AgentTokenStreamingHandler(self.callback_manager), + **template_params, + ) return prompt_node_response def create_agent_step(self, max_steps: Optional[int] = None) -> AgentStep: @@ -422,3 +420,35 @@ class Agent: return { k: v if isinstance(v, str) else next(iter(v)) for k, v in kwargs.items() if isinstance(v, (str, Iterable)) } + + def check_prompt_template(self, template_params: Dict[str, Any]) -> None: + """ + Verifies that the Agent's prompt template is adequately populated with the correct parameters + provided by the prompt parameter resolver. + + If template_params contains a parameter that is not specified in the prompt template, a warning is logged + at DEBUG level. Sometimes the prompt parameter resolver may provide additional parameters that are not + used by the prompt template. However, if the prompt parameter resolver provides a 'transcript' + parameter that is not used in the prompt template, an error is logged. + + :param template_params: The parameters provided by the prompt parameter resolver. + + """ + unused_params = set(template_params.keys()) - set(self.prompt_template.prompt_params) + + if "transcript" in unused_params: + logger.warning( + "The 'transcript' parameter is missing from the Agent's prompt template. All ReAct agents " + "that go through multiple steps to reach a goal require this parameter. Please append {transcript} " + "to the end of the Agent's prompt template to ensure its proper functioning. A temporary prompt " + "template with {transcript} appended will be used for this run." + ) + new_prompt_text = self.prompt_template.prompt_text + "\n {transcript}" + self.prompt_template = PromptTemplate(prompt=new_prompt_text) + + elif unused_params: + logger.debug( + "The Agent's prompt template does not utilize the following parameters provided by the " + "prompt parameter resolver: %s. Note that these parameters are available for use if needed.", + list(unused_params), + ) diff --git a/test/agents/test_agent.py b/test/agents/test_agent.py index 971e58621..a44b567ae 100644 --- a/test/agents/test_agent.py +++ b/test/agents/test_agent.py @@ -2,18 +2,17 @@ import logging import os import re from typing import Tuple -from unittest.mock import Mock, patch - -from events import Events - -from haystack.agents.types import AgentTokenStreamingHandler, AgentToolLogger -from test.conftest import MockRetriever, MockPromptNode from unittest import mock +from unittest.mock import Mock, patch +from test.conftest import MockRetriever, MockPromptNode + import pytest +from events import Events from haystack import BaseComponent, Answer, Document from haystack.agents import Agent, AgentStep from haystack.agents.base import Tool, ToolsManager +from haystack.agents.types import AgentTokenStreamingHandler, AgentToolLogger from haystack.nodes import PromptModel, PromptNode, PromptTemplate from haystack.pipelines import ExtractiveQAPipeline, DocumentSearchPipeline, BaseStandardPipeline @@ -356,3 +355,77 @@ def test_agent_token_streaming_handler(): assert result == "test" mock_callback.assert_called_once_with("test") # assert that the mock callback was called with "test" + + +@pytest.mark.unit +def test_agent_prompt_template_parameter_has_transcript(caplog): + mock_prompt_node = Mock(spec=PromptNode) + prompt = PromptTemplate(prompt="I now have {query} as a template parameter but also {transcript}") + mock_prompt_node.get_prompt_template.return_value = prompt + + agent = Agent(prompt_node=mock_prompt_node) + agent.check_prompt_template({"query": "test", "transcript": "some fake transcript"}) + assert "The 'transcript' parameter is missing from the Agent's prompt template" not in caplog.text + + +@pytest.mark.unit +def test_agent_prompt_template_has_no_transcript(caplog): + mock_prompt_node = Mock(spec=PromptNode) + prompt = PromptTemplate(prompt="I only have {query} as a template parameter but I am missing transcript variable") + mock_prompt_node.get_prompt_template.return_value = prompt + agent = Agent(prompt_node=mock_prompt_node) + + # We start with no transcript in the prompt template + assert "transcript" not in prompt.prompt_params + assert "transcript" not in agent.prompt_template.prompt_params + + agent.check_prompt_template({"query": "test", "transcript": "some fake transcript"}) + assert "The 'transcript' parameter is missing from the Agent's prompt template" in caplog.text + + # now let's check again after adding the transcript + # query was there to begin with + assert "query" in agent.prompt_template.prompt_params + # transcript was added automatically for this prompt template and run + assert "transcript" in agent.prompt_template.prompt_params + + +@pytest.mark.unit +def test_agent_prompt_template_unused_parameters(caplog): + caplog.set_level(logging.DEBUG) + mock_prompt_node = Mock(spec=PromptNode) + prompt = PromptTemplate(prompt="I now have {query} and {transcript} as template parameters") + mock_prompt_node.get_prompt_template.return_value = prompt + agent = Agent(prompt_node=mock_prompt_node) + agent.check_prompt_template({"query": "test", "transcript": "some fake transcript", "unused": "test"}) + assert ( + "The Agent's prompt template does not utilize the following parameters provided by the " + "prompt parameter resolver: ['unused']" in caplog.text + ) + + +@pytest.mark.unit +def test_agent_prompt_template_multiple_unused_parameters(caplog): + caplog.set_level(logging.DEBUG) + mock_prompt_node = Mock(spec=PromptNode) + prompt = PromptTemplate(prompt="I now have strange {param_1} and {param_2} as template parameters") + mock_prompt_node.get_prompt_template.return_value = prompt + agent = Agent(prompt_node=mock_prompt_node) + agent.check_prompt_template({"query": "test", "unused": "test"}) + # order of parameters in the list not guaranteed, so we check for preamble of the message + assert ( + "The Agent's prompt template does not utilize the following parameters provided by the " + "prompt parameter resolver" in caplog.text + ) + + +@pytest.mark.unit +def test_agent_prompt_template_missing_parameters(caplog): + # in check_prompt_template we don't check that all prompt template parameters are filled + # prompt template resolution will do that and flag the missing parameters + # in check_prompt_template we check if some template parameters are not used + mock_prompt_node = Mock(spec=PromptNode) + prompt = PromptTemplate(prompt="I now have {query} and {transcript} as template parameters") + mock_prompt_node.get_prompt_template.return_value = prompt + agent = Agent(prompt_node=mock_prompt_node) + agent.check_prompt_template({"transcript": "test"}) + assert not caplog.text