haystack/test/agents/test_conversational_agent.py
Bilge Yücel 6a1b6b1ae3
feat: Update ConversationalAgent (#5065)
* feat: Update ConversationalAgent

* Add Tools
* Add test
* Change default params

* fix tests

* Fix circular import error
* Update conversational-agent prompt
* Add conversational-agent-without-tools to legacy list

* Add warning to add tools to conversational agent

* Add callable tools

* Add example script

* Fix linter errors

* Update ConversationalAgent depending on the existence of tools

* Initialize the base Agent with different arguments when there are tools
* Inject memory to the prompt in both cases, update prompts accordingly

* Override the add_tools method to prevent adding tools to ConversationalAgent without tools

* Update test

* Fix linter error

* Remove unused import

* Update docstrings and api reference

* Fix imports and doc string code snippet

* docstrings update

* Update conversational.py

* Mock PromptNode

* Prevent circular import error

* Add max_steps to the ConversationalAgent

* Update resolver description

* Add prompt_template as parameter

* Change docstring

---------

Co-authored-by: Darja Fokina <daria.f93@gmail.com>
2023-06-20 13:09:21 +03:00

106 lines
3.5 KiB
Python

import re
from unittest.mock import MagicMock, patch

import pytest

from haystack.errors import AgentError
from haystack.agents.base import Tool
from haystack.agents.conversational import ConversationalAgent
from haystack.agents.memory import ConversationSummaryMemory, ConversationMemory, NoMemory
from haystack.nodes import PromptNode
@pytest.fixture
def prompt_node():
    """Provide a PromptNode constructed while PromptModel is patched with a mock.

    The patch targets ``haystack.nodes.prompt.prompt_node.PromptModel`` and is
    active only for the duration of the ``PromptNode()`` call (presumably so no
    real model is loaded in unit tests — confirm against PromptNode).
    """
    with patch("haystack.nodes.prompt.prompt_node.PromptModel"):
        node = PromptNode()
    return node
@pytest.mark.unit
def test_init_without_tools(prompt_node):
    """A ConversationalAgent created without tools uses the tool-less defaults."""
    conversational_agent = ConversationalAgent(prompt_node)

    # Defaults applied when no tools are supplied
    assert isinstance(conversational_agent.memory, ConversationMemory)
    assert callable(conversational_agent.prompt_parameters_resolver)
    assert conversational_agent.max_steps == 2
    assert conversational_agent.final_answer_pattern == r"^([\s\S]+)$"
    assert conversational_agent.prompt_template.name == "conversational-agent-without-tools"

    # Nothing is registered with the tool manager
    assert not conversational_agent.tm.tools
@pytest.mark.unit
def test_init_with_tools(prompt_node):
    """A ConversationalAgent created with tools switches to the tool-using defaults."""
    example_tool = Tool("ExampleTool", lambda x: x, description="Example tool")
    agent_with_tools = ConversationalAgent(prompt_node, tools=[example_tool])

    # Defaults applied when tools are supplied
    assert isinstance(agent_with_tools.memory, ConversationMemory)
    assert callable(agent_with_tools.prompt_parameters_resolver)
    assert agent_with_tools.max_steps == 5
    assert agent_with_tools.final_answer_pattern == r"Final Answer\s*:\s*(.*)"
    assert agent_with_tools.prompt_template.name == "conversational-agent"

    # The supplied tool is registered by name
    assert agent_with_tools.has_tool("ExampleTool")
@pytest.mark.unit
def test_init_with_summary_memory(prompt_node):
    """A ConversationSummaryMemory passed at init is used as the agent's memory type."""
    summary_memory = ConversationSummaryMemory(prompt_node)
    agent = ConversationalAgent(prompt_node, memory=summary_memory)
    assert isinstance(agent.memory, ConversationSummaryMemory)
@pytest.mark.unit
def test_init_with_no_memory(prompt_node):
    """Passing NoMemory at init leaves the agent with a NoMemory instance."""
    disabled_memory = NoMemory()
    agent = ConversationalAgent(prompt_node, memory=disabled_memory)
    assert isinstance(agent.memory, NoMemory)
@pytest.mark.unit
def test_init_with_custom_max_steps(prompt_node):
    """An explicit max_steps value overrides the default step limit."""
    requested_steps = 8
    agent = ConversationalAgent(prompt_node, max_steps=requested_steps)
    assert agent.max_steps == requested_steps
@pytest.mark.unit
def test_init_with_custom_prompt_template(prompt_node):
    """A prompt template given by name replaces the default conversational template."""
    template_name = "translation"
    agent = ConversationalAgent(prompt_node, prompt_template=template_name)
    assert agent.prompt_template.name == template_name
@pytest.mark.unit
def test_run(prompt_node):
    """Calling ``run`` with a query returns the mocked response exactly once."""
    agent = ConversationalAgent(prompt_node)
    # NOTE(review): this replaces ``agent.run`` itself with a mock, so the
    # assertions below only exercise the MagicMock, not ConversationalAgent's
    # run logic. Consider patching a lower-level dependency (e.g. the base
    # Agent.run or the PromptNode call) instead — confirm against conversational.py.
    mocked_run = MagicMock(return_value="Hello")
    agent.run = mocked_run

    response = agent.run("query")

    assert response == "Hello"
    mocked_run.assert_called_once_with("query")
@pytest.mark.unit
def test_add_tool(prompt_node):
    """Tools can be added after init when the agent was created with tools.

    Checks both the tool count and that each tool is registered under its own
    name, so storing the wrong tool would not go unnoticed.
    """
    agent = ConversationalAgent(prompt_node, tools=[Tool("ExampleTool", lambda x: x, description="Example tool")])
    # ConversationalAgent starts with the single tool passed at init
    assert len(agent.tm.tools) == 1
    assert agent.has_tool("ExampleTool")
    # Adding more tools is allowed because the agent was initialized with tools
    agent.add_tool(Tool("AnotherTool", lambda x: x, description="Another example tool"))
    assert len(agent.tm.tools) == 2
    assert agent.has_tool("AnotherTool")
    # The original tool is still present after the addition
    assert agent.has_tool("ExampleTool")
@pytest.mark.unit
def test_add_tool_not_allowed(prompt_node):
    """Adding a tool to a ConversationalAgent created without tools raises AgentError."""
    agent = ConversationalAgent(prompt_node)
    # ConversationalAgent has no tools
    assert not agent.tm.tools
    # Build the tool outside ``pytest.raises`` so only ``add_tool`` is expected
    # to raise; an unexpected error from ``Tool(...)`` should fail the test.
    tool = Tool("ExampleTool", lambda x: x, description="Example tool")
    # ``match`` is applied with re.search, so escape the message for a literal match.
    expected_message = re.escape(
        "You cannot add tools after initializing the ConversationalAgent without any tools."
    )
    # ...and tools can't be added when the agent was initialized without tools
    with pytest.raises(AgentError, match=expected_message):
        agent.add_tool(tool)