2023-05-17 15:19:09 +02:00
|
|
|
import pytest
|
2023-06-20 13:09:21 +03:00
|
|
|
from unittest.mock import MagicMock, patch
|
2023-05-17 15:19:09 +02:00
|
|
|
|
2023-06-20 13:09:21 +03:00
|
|
|
from haystack.errors import AgentError
|
|
|
|
from haystack.agents.base import Tool
|
2023-05-17 15:19:09 +02:00
|
|
|
from haystack.agents.conversational import ConversationalAgent
|
|
|
|
from haystack.agents.memory import ConversationSummaryMemory, ConversationMemory, NoMemory
|
2023-06-20 13:09:21 +03:00
|
|
|
from haystack.nodes import PromptNode
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture
def prompt_node():
    """Provide a PromptNode built while ``PromptModel`` is patched out.

    ``haystack.nodes.prompt.prompt_node.PromptModel`` is replaced by a mock only
    for the duration of the ``PromptNode()`` construction; the resulting node is
    then handed to the test.
    """
    with patch("haystack.nodes.prompt.prompt_node.PromptModel"):
        node = PromptNode()
    return node
|
2023-05-17 15:19:09 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.unit
def test_init_without_tools(prompt_node):
    """A ConversationalAgent built from only a PromptNode uses the tool-less defaults."""
    conv_agent = ConversationalAgent(prompt_node)

    # Values expected when nothing but the PromptNode is supplied
    catch_all_pattern = r"^([\s\S]+)$"
    assert isinstance(conv_agent.memory, ConversationMemory)
    assert callable(conv_agent.prompt_parameters_resolver)
    assert conv_agent.max_steps == 2
    assert conv_agent.final_answer_pattern == catch_all_pattern
    assert conv_agent.prompt_template.name == "conversational-agent-without-tools"

    # No tools were passed in, so the tool manager must be empty
    assert not conv_agent.tm.tools
|
|
|
|
|
2023-05-23 15:22:58 +02:00
|
|
|
|
|
|
|
@pytest.mark.unit
def test_init_with_tools(prompt_node):
    """Supplying tools switches the agent to the tool-enabled configuration."""
    example_tool = Tool("ExampleTool", lambda x: x, description="Example tool")
    conv_agent = ConversationalAgent(prompt_node, tools=[example_tool])

    # Configuration expected once at least one tool is registered
    assert isinstance(conv_agent.memory, ConversationMemory)
    assert callable(conv_agent.prompt_parameters_resolver)
    assert conv_agent.max_steps == 5
    assert conv_agent.final_answer_pattern == r"Final Answer\s*:\s*(.*)"
    assert conv_agent.prompt_template.name == "conversational-agent"
    assert conv_agent.has_tool("ExampleTool")
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.unit
def test_init_with_summary_memory(prompt_node):
    """An explicitly passed ConversationSummaryMemory is kept as the agent's memory."""
    summary_memory = ConversationSummaryMemory(prompt_node)
    conv_agent = ConversationalAgent(prompt_node, memory=summary_memory)

    assert isinstance(conv_agent.memory, ConversationSummaryMemory)
|
2023-05-17 15:19:09 +02:00
|
|
|
|
2023-05-23 15:22:58 +02:00
|
|
|
|
|
|
|
@pytest.mark.unit
def test_init_with_no_memory(prompt_node):
    """Passing NoMemory leaves the agent configured with NoMemory."""
    disabled_memory = NoMemory()
    conv_agent = ConversationalAgent(prompt_node, memory=disabled_memory)

    assert isinstance(conv_agent.memory, NoMemory)
|
2023-05-17 15:19:09 +02:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.unit
def test_init_with_custom_max_steps(prompt_node):
    """A caller-supplied max_steps value overrides the default."""
    # Test with custom max steps
    custom_steps = 8
    conv_agent = ConversationalAgent(prompt_node, max_steps=custom_steps)

    assert conv_agent.max_steps == custom_steps
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.unit
def test_init_with_custom_prompt_template(prompt_node):
    """A prompt template given by name is resolved and stored on the agent."""
    template_name = "translation"
    conv_agent = ConversationalAgent(prompt_node, prompt_template=template_name)

    assert conv_agent.prompt_template.name == template_name
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.unit
def test_run(prompt_node):
    """Calling ``run`` on the agent returns the value produced by the mocked run."""
    conv_agent = ConversationalAgent(prompt_node)

    # Swap the instance's run method for a mock with a fixed return value.
    # NOTE(review): because ``run`` itself is mocked, this only exercises the mock,
    # not ConversationalAgent's real run logic; consider mocking the PromptNode's
    # output instead so the actual run path is covered.
    run_mock = MagicMock(return_value="Hello")
    conv_agent.run = run_mock

    answer = conv_agent.run("query")
    assert answer == "Hello"
    run_mock.assert_called_once_with("query")
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.unit
def test_add_tool(prompt_node):
    """An agent created with tools accepts additional tools through add_tool."""
    first_tool = Tool("ExampleTool", lambda x: x, description="Example tool")
    conv_agent = ConversationalAgent(prompt_node, tools=[first_tool])
    # Exactly the one tool passed at construction is registered
    assert len(conv_agent.tm.tools) == 1

    # Because the agent started with tools, adding another is permitted
    second_tool = Tool("AnotherTool", lambda x: x, description="Example tool")
    conv_agent.add_tool(second_tool)
    assert len(conv_agent.tm.tools) == 2
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.unit
def test_add_tool_not_allowed(prompt_node):
    """add_tool raises AgentError on an agent that was created without any tools."""
    conv_agent = ConversationalAgent(prompt_node)
    # Built without tools, so none are registered
    assert not conv_agent.tm.tools

    # pytest.raises(match=...) performs a regex search; the '.' in the message
    # therefore matches any character, which is loose but harmless here.
    expected_msg = "You cannot add tools after initializing the ConversationalAgent without any tools."
    with pytest.raises(AgentError, match=expected_msg):
        conv_agent.add_tool(Tool("ExampleTool", lambda x: x, description="Example tool"))
|