import logging
import os
import re
from typing import Tuple
from unittest import mock
from unittest.mock import Mock, patch
from test.conftest import MockRetriever, MockPromptNode

import pytest
from events import Events

from haystack import BaseComponent, Answer, Document
from haystack.agents import Agent, AgentStep
from haystack.agents.base import Tool, ToolsManager
from haystack.agents.types import AgentTokenStreamingHandler, AgentToolLogger
from haystack.nodes import PromptModel, PromptNode, PromptTemplate
from haystack.pipelines import ExtractiveQAPipeline, DocumentSearchPipeline, BaseStandardPipeline


@pytest.mark.unit
def test_add_and_overwrite_tool(caplog):
    # Add a Node as a Tool to an Agent
    agent = Agent(prompt_node=MockPromptNode())
    retriever = MockRetriever()
    agent.add_tool(
        Tool(
            name="Retriever",
            pipeline_or_node=retriever,
            description="useful for when you need to retrieve documents from your index",
        )
    )
    assert len(agent.tm.tools) == 1
    assert agent.has_tool(tool_name="Retriever")
    assert isinstance(agent.tm.tools["Retriever"].pipeline_or_node, BaseComponent)

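    # add_tool with an already-registered name replaces the previous Tool rather than adding a second one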
    agent.add_tool(
        Tool(
            name="Retriever",
            pipeline_or_node=retriever,
            description="useful for when you need to retrieve documents from your index",
        )
    )

    # Add a Pipeline as a Tool to an Agent and overwrite the previously added Tool
    retriever_pipeline = DocumentSearchPipeline(MockRetriever())
    with caplog.at_level(logging.WARNING):
        agent.add_tool(
            Tool(
                name="Retriever",
                pipeline_or_node=retriever_pipeline,
                description="useful for when you need to retrieve documents from your index",
            )
        )
    assert (
        "The agent already has a tool named 'Retriever'. The new tool will overwrite the existing one." in caplog.text
    )
    assert len(agent.tm.tools) == 1
    assert agent.has_tool(tool_name="Retriever")
    assert isinstance(agent.tm.tools["Retriever"].pipeline_or_node, BaseStandardPipeline)


@pytest.mark.unit
def test_max_steps(caplog, monkeypatch):
    # Run an Agent and stop because max_steps is reached
    agent = Agent(prompt_node=MockPromptNode(), max_steps=3)
    retriever = MockRetriever()
    agent.add_tool(
        Tool(
            name="Retriever",
            pipeline_or_node=retriever,
            description="useful for when you need to retrieve documents from your index",
            output_variable="documents",
        )
    )

    # Let the Agent always choose "Retriever" as the Tool with "" as input
    def mock_extract_tool_name_and_tool_input(self, pred: str) -> Tuple[str, str]:
        return "Retriever", ""

    monkeypatch.setattr(ToolsManager, "extract_tool_name_and_tool_input", mock_extract_tool_name_and_tool_input)

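    # MockPromptNode never produces a "Final Answer:" line, so with the patched tool parser the
    # Agent keeps looping until max_steps is exhausted and returns an empty generative Answer.
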
    # Using max_steps as specified in the Agent's init method
    with caplog.at_level(logging.WARNING, logger="haystack.agents"):
        result = agent.run("Where does Christelle live?")
        assert result["answers"] == [Answer(answer="", type="generative")]
        assert "maximum number of iterations (3)" in caplog.text.lower()

    # Setting max_steps in the Agent's run method
    with caplog.at_level(logging.WARNING, logger="haystack.agents"):
        result = agent.run("Where does Christelle live?", max_steps=2)
        assert result["answers"] == [Answer(answer="", type="generative")]
        assert "maximum number of iterations (2)" in caplog.text.lower()


@pytest.mark.unit
def test_run_tool():
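    # ToolsManager parses the PromptNode response for lines of the form:
    #   Tool: <tool name>
    #   Tool Input: <input passed to the tool>
    # and routes the input to the matching tool.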
    agent = Agent(prompt_node=MockPromptNode())
    retriever = MockRetriever()
    agent.add_tool(
        Tool(
            name="Retriever",
            pipeline_or_node=retriever,
            description="useful for when you need to retrieve documents from your index",
            output_variable="documents",
        )
    )
    pn_response = "need to find out what city he was born.\nTool: Retriever\nTool Input: Where was Jeremy McKinnon born"

    step = AgentStep(prompt_node_response=pn_response)
    result = agent.tm.run_tool(step.prompt_node_response)
    assert result == "[]"  # empty list of documents


@pytest.mark.unit
def test_agent_tool_logger():
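    # AgentToolLogger subscribes to the Agent's and ToolsManager's callback events and records
    # one entry per tool call: tool name, tool input, and tool output.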
    agent = Agent(prompt_node=MockPromptNode())
    atl = AgentToolLogger(agent_events=agent.callback_manager, tool_events=agent.tm.callback_manager)

    retriever = MockRetriever()
    agent.add_tool(
        Tool(
            name="Retriever",
            pipeline_or_node=retriever,
            description="useful for when you need to retrieve documents from your index",
            output_variable="documents",
        )
    )
    pn_response = "need to find out what city he was born.\nTool: Retriever\nTool Input: Where was Jeremy McKinnon born"

    step = AgentStep(prompt_node_response=pn_response)
    agent.tm.run_tool(step.prompt_node_response)

    # Check that the AgentToolLogger collected the tool's output
    assert len(atl.logs) == 1
    tool_logging_event = atl.logs[0]
    assert tool_logging_event["tool_name"] == "Retriever"
    assert tool_logging_event["tool_input"] == "Where was Jeremy McKinnon born"
    assert tool_logging_event["tool_output"] == "[]"


@pytest.mark.unit
def test_extract_final_answer():
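    # The first capture group of final_answer_pattern becomes the Answer text, with surrounding
    # whitespace stripped (e.g. "Final Answer: The answer is 99 " -> "The answer is 99").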
    match_examples = [
        "have the final answer to the question.\nFinal Answer: Florida",
        "Final Answer: 42 is the answer",
        "Final Answer: 1234",
        "Final Answer: Answer",
        "Final Answer: This list: one and two and three",
        "Final Answer:42",
        "Final Answer: ",
        "Final Answer: The answer is 99 ",
    ]
    expected_answers = [
        "Florida",
        "42 is the answer",
        "1234",
        "Answer",
        "This list: one and two and three",
        "42",
        "",
        "The answer is 99",
    ]

    for example, expected_answer in zip(match_examples, expected_answers):
        agent_step = AgentStep(prompt_node_response=example, final_answer_pattern=r"Final Answer\s*:\s*(.*)")
        final_answer = agent_step.final_answer(query="irrelevant")
        assert final_answer["answers"][0].answer == expected_answer


@pytest.mark.unit
def test_final_answer_regex():
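    # re.match anchors at the start of the string, so a response must begin with the case-sensitive
    # "Final Answer" prefix to match; "Final answer:" and mid-string mentions do not.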
    match_examples = [
        "Final Answer: 42 is the answer",
        "Final Answer: 1234",
        "Final Answer: Answer",
        "Final Answer: This list: one and two and three",
        "Final Answer:42",
        "Final Answer: ",
        "Final Answer: The answer is 99 ",
    ]

    non_match_examples = ["Final answer: 42 is the answer", "Final Answer", "The final answer is: 100"]
    final_answer_pattern = r"Final Answer\s*:\s*(.*)"
    for example in match_examples:
        match = re.match(final_answer_pattern, example)
        assert match is not None

    for example in non_match_examples:
        match = re.match(final_answer_pattern, example)
        assert match is None


@pytest.mark.integration
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
@pytest.mark.parametrize("retriever_with_docs, document_store_with_docs", [("bm25", "memory")], indirect=True)
def test_tool_result_extraction(reader, retriever_with_docs):
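    # Tool.run always returns a string; output_variable selects which key of the wrapped
    # pipeline's or node's output (e.g. "answers", "results", "documents") gets stringified.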
    # Test that the result of a Tool is correctly extracted as a string

    # Pipeline as a Tool
    search = ExtractiveQAPipeline(reader, retriever_with_docs)
    t = Tool(
        name="Search",
        pipeline_or_node=search,
        description="useful for when you need to answer "
        "questions about where people live. You "
        "should ask targeted questions",
        output_variable="answers",
    )
    result = t.run("Where does Christelle live?")
    assert isinstance(result, str)
    assert result == "Paris" or result == "Madrid"

    # PromptNode as a Tool
    pt = PromptTemplate("Here is a question: {query}, Answer:")
    pn = PromptNode(default_prompt_template=pt)

    t = Tool(name="Search", pipeline_or_node=pn, description="N/A", output_variable="results")
    result = t.run(tool_input="What is the capital of Germany?")
    assert isinstance(result, str)
    assert "berlin" in result.lower()

    # Retriever as a Tool
    t = Tool(
        name="Retriever",
        pipeline_or_node=retriever_with_docs,
        description="useful for when you need to retrieve documents from your index",
        output_variable="documents",
    )
    result = t.run(tool_input="Where does Christelle live?")
    assert isinstance(result, str)
    assert "Christelle" in result


@pytest.mark.skip("FIXME")
@pytest.mark.integration
@pytest.mark.parametrize("reader", ["farm"], indirect=True)
@pytest.mark.parametrize("retriever_with_docs, document_store_with_docs", [("bm25", "memory")], indirect=True)
@pytest.mark.skipif(
    not os.environ.get("OPENAI_API_KEY", None),
    reason="Please export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
)
def test_agent_run(reader, retriever_with_docs, document_store_with_docs):
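    # End-to-end ReAct-style run: stop_words=["Observation:"] cuts generation at the point where
    # the Agent inserts the actual tool output, before prompting the model again.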
    search = ExtractiveQAPipeline(reader, retriever_with_docs)
    prompt_model = PromptModel(model_name_or_path="gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"))
    prompt_node = PromptNode(model_name_or_path=prompt_model, stop_words=["Observation:"])
    country_finder = PromptNode(
        model_name_or_path=prompt_model,
        default_prompt_template=PromptTemplate(
            "When I give you a name of the city, respond with the country where the city is located.\n"
            "City: Rome\nCountry: Italy\n"
            "City: Berlin\nCountry: Germany\n"
            "City: Belgrade\nCountry: Serbia\n"
            "City: {query}?\nCountry: "
        ),
    )

    agent = Agent(prompt_node=prompt_node, max_steps=12)
    agent.add_tool(
        Tool(
            name="Search",
            pipeline_or_node=search,
            description="useful for when you need to answer "
            "questions about where people live. You "
            "should ask targeted questions",
            output_variable="answers",
        )
    )
    agent.add_tool(
        Tool(
            name="CountryFinder",
            pipeline_or_node=country_finder,
            description="useful for when you need to find the country where a city is located",
        )
    )

    result = agent.run("Where is Madrid?")
    country = result["answers"][0].answer
    assert "spain" in country.lower()

    result = agent.run("In which country is the city where Christelle lives?")
    country = result["answers"][0].answer
    assert "france" in country.lower()


@pytest.mark.unit
def test_update_hash():
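    # "d41d8cd98f00b204e9800998ecf8427e" is the MD5 digest of the empty string; adding tools does
    # not touch the hash until update_hash() explicitly recomputes it.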
    agent = Agent(prompt_node=MockPromptNode(), prompt_template=mock.Mock())
    assert agent.hash == "d41d8cd98f00b204e9800998ecf8427e"
    agent.add_tool(
        Tool(
            name="Search",
            pipeline_or_node=mock.Mock(),
            description="useful for when you need to answer "
            "questions about where people live. You "
            "should ask targeted questions",
            output_variable="answers",
        )
    )
    assert agent.hash == "d41d8cd98f00b204e9800998ecf8427e"
    agent.add_tool(
        Tool(
            name="Count",
            pipeline_or_node=mock.Mock(),
            description="useful for when you need to count how many characters are in a word. Ask only with a single word.",
        )
    )
    assert agent.hash == "d41d8cd98f00b204e9800998ecf8427e"
    agent.update_hash()
    assert agent.hash == "5ac8eca2f92c9545adcce3682b80d4c5"


@pytest.mark.unit
def test_tool_fails_processing_dict_result():
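    # A raw dict result is rejected with a ValueError, while Answer and Document results are
    # unpacked to their text (see the next test).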
    tool = Tool(name="name", pipeline_or_node=MockPromptNode(), description="description")
    with pytest.raises(ValueError):
        tool._process_result({"answer": "answer"})


@pytest.mark.unit
def test_tool_processes_answer_result_and_document_result():
    tool = Tool(name="name", pipeline_or_node=MockPromptNode(), description="description")
    assert tool._process_result(Answer(answer="answer")) == "answer"
    assert tool._process_result(Document(content="content")) == "content"


@pytest.mark.unit
@patch.object(PromptNode, "prompt")
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def test_default_template_order(mock_model, mock_prompt):
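    # Prompt template resolution order: an explicit prompt_template argument wins, then the
    # PromptNode's default_prompt_template, and finally the built-in "zero-shot-react" fallback.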
    pn = PromptNode("abc")
    a = Agent(prompt_node=pn)
    assert a.prompt_template.name == "zero-shot-react"

    pn.default_prompt_template = "language-detection"
    a = Agent(prompt_node=pn)
    assert a.prompt_template.name == "language-detection"

    a = Agent(prompt_node=pn, prompt_template="translation")
    assert a.prompt_template.name == "translation"


@pytest.mark.unit
def test_agent_with_unknown_prompt_template():
    prompt_node = Mock()
    prompt_node.get_prompt_template.return_value = None
    with pytest.raises(ValueError, match="Prompt template 'invalid' not found"):
        Agent(prompt_node=prompt_node, prompt_template="invalid")


@pytest.mark.unit
def test_agent_token_streaming_handler():
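    # The streaming handler fires the on_new_token event for each token it receives and passes
    # the token through unchanged.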
    e = Events("on_new_token")

    mock_callback = Mock()
    e.on_new_token += mock_callback  # register the mock callback to the event

    handler = AgentTokenStreamingHandler(events=e)
    result = handler("test")

    assert result == "test"
    mock_callback.assert_called_once_with("test")  # assert that the mock callback was called with "test"


@pytest.mark.unit
def test_agent_prompt_template_parameter_has_transcript(caplog):
    mock_prompt_node = Mock(spec=PromptNode)
    prompt = PromptTemplate(prompt="I now have {query} as a template parameter but also {transcript}")
    mock_prompt_node.get_prompt_template.return_value = prompt

    agent = Agent(prompt_node=mock_prompt_node)
    agent.check_prompt_template({"query": "test", "transcript": "some fake transcript"})
    assert "The 'transcript' parameter is missing from the Agent's prompt template" not in caplog.text


@pytest.mark.unit
def test_agent_prompt_template_has_no_transcript(caplog):
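    # check_prompt_template warns when a transcript value is supplied but the template has no
    # {transcript} slot, and then adds the parameter to the template for this run.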
    mock_prompt_node = Mock(spec=PromptNode)
    prompt = PromptTemplate(prompt="I only have {query} as a template parameter but I am missing transcript variable")
    mock_prompt_node.get_prompt_template.return_value = prompt
    agent = Agent(prompt_node=mock_prompt_node)

    # We start with no transcript in the prompt template
    assert "transcript" not in prompt.prompt_params
    assert "transcript" not in agent.prompt_template.prompt_params

    agent.check_prompt_template({"query": "test", "transcript": "some fake transcript"})
    assert "The 'transcript' parameter is missing from the Agent's prompt template" in caplog.text

    # now let's check again after adding the transcript
    # query was there to begin with
    assert "query" in agent.prompt_template.prompt_params
    # transcript was added automatically for this prompt template and run
    assert "transcript" in agent.prompt_template.prompt_params


@pytest.mark.unit
def test_agent_prompt_template_unused_parameters(caplog):
    caplog.set_level(logging.DEBUG)
    mock_prompt_node = Mock(spec=PromptNode)
    prompt = PromptTemplate(prompt="I now have {query} and {transcript} as template parameters")
    mock_prompt_node.get_prompt_template.return_value = prompt
    agent = Agent(prompt_node=mock_prompt_node)
    agent.check_prompt_template({"query": "test", "transcript": "some fake transcript", "unused": "test"})
    assert (
        "The Agent's prompt template does not utilize the following parameters provided by the "
        "prompt parameter resolver: ['unused']" in caplog.text
    )


@pytest.mark.unit
def test_agent_prompt_template_multiple_unused_parameters(caplog):
    caplog.set_level(logging.DEBUG)
    mock_prompt_node = Mock(spec=PromptNode)
    prompt = PromptTemplate(prompt="I now have strange {param_1} and {param_2} as template parameters")
    mock_prompt_node.get_prompt_template.return_value = prompt
    agent = Agent(prompt_node=mock_prompt_node)
    agent.check_prompt_template({"query": "test", "unused": "test"})
    # The order of parameters in the list is not guaranteed, so we only check the preamble of the message
    assert (
        "The Agent's prompt template does not utilize the following parameters provided by the "
        "prompt parameter resolver" in caplog.text
    )


@pytest.mark.unit
def test_agent_prompt_template_missing_parameters(caplog):
    # check_prompt_template doesn't verify that all prompt template parameters are filled;
    # prompt template resolution does that and flags any missing parameters.
    # Here, check_prompt_template only checks whether some provided parameters go unused.
    mock_prompt_node = Mock(spec=PromptNode)
    prompt = PromptTemplate(prompt="I now have {query} and {transcript} as template parameters")
    mock_prompt_node.get_prompt_template.return_value = prompt
    agent = Agent(prompt_node=mock_prompt_node)
    agent.check_prompt_template({"transcript": "test"})
    assert not caplog.text