mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-12-25 14:08:27 +00:00
feat: Update ConversationalAgent (#5065)
* feat: Update ConversationalAgent * Add Tools * Add test * Change default params * fix tests * Fix circular import error * Update conversational-agent prompt * Add conversational-agent-without-tools to legacy list * Add warning to add tools to conversational agent * Add callable tools * Add example script * Fix linter errors * Update ConversationalAgent depending on the existance of tools * Initialize the base Agent with different arguments when there's tool * Inject memory to the prompt in both cases, update prompts accordingly * Override the add_tools method to prevent adding tools to ConversationalAgent without tools * Update test * Fix linter error * Remove unused import * Update docstrings and api reference * Fix imports and doc string code snippet * docstrings update * Update conversational.py * Mock PromptNode * Prevent circular import error * Add max_steps to the ConversationalAgent * Update resolver description * Add prompt_template as parameter * Change docstring --------- Co-authored-by: Darja Fokina <daria.f93@gmail.com>
This commit is contained in:
parent
30fdf2b5df
commit
6a1b6b1ae3
@ -1,7 +1,7 @@
|
||||
loaders:
|
||||
- type: python
|
||||
search_path: [../../../haystack/agents]
|
||||
modules: ['base']
|
||||
modules: ['base', 'conversational']
|
||||
ignore_when_discovered: ['__init__']
|
||||
processors:
|
||||
- type: filter
|
||||
|
||||
@ -1,17 +1,66 @@
|
||||
import os
|
||||
|
||||
from haystack.agents.base import Tool
|
||||
from haystack.agents.conversational import ConversationalAgent
|
||||
from haystack.nodes import PromptNode
|
||||
from haystack.agents.memory import ConversationSummaryMemory
|
||||
from haystack.nodes import PromptNode, WebRetriever, PromptTemplate
|
||||
from haystack.pipelines import WebQAPipeline
|
||||
from haystack.agents.types import Color
|
||||
|
||||
pn = PromptNode("gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"), max_length=256)
|
||||
agent = ConversationalAgent(pn)
|
||||
search_api_key = os.environ.get("SEARCH_API_KEY")
|
||||
if not search_api_key:
|
||||
raise ValueError("Please set the SEARCH_API_KEY environment variable")
|
||||
openai_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
if not openai_api_key:
|
||||
raise ValueError("Please set the OPENAI_API_KEY environment variable")
|
||||
|
||||
while True:
|
||||
user_input = input("Human (type 'exit' or 'quit' to quit, 'memory' for agent's memory): ")
|
||||
if user_input.lower() == "exit" or user_input.lower() == "quit":
|
||||
break
|
||||
if user_input.lower() == "memory":
|
||||
print("\nMemory:\n", agent.memory.load())
|
||||
else:
|
||||
assistant_response = agent.run(user_input)
|
||||
print("\nAssistant:", assistant_response)
|
||||
web_prompt = """
|
||||
Synthesize a comprehensive answer from the following most relevant paragraphs and the given question.
|
||||
Provide a clear and concise answer, no longer than 10-20 words.
|
||||
\n\n Paragraphs: {documents} \n\n Question: {query} \n\n Answer:
|
||||
"""
|
||||
|
||||
web_prompt_node = PromptNode(
|
||||
"gpt-3.5-turbo", default_prompt_template=PromptTemplate(prompt=web_prompt), api_key=openai_api_key
|
||||
)
|
||||
|
||||
web_retriever = WebRetriever(api_key=search_api_key, top_search_results=3, mode="snippets")
|
||||
pipeline = WebQAPipeline(retriever=web_retriever, prompt_node=web_prompt_node)
|
||||
web_qa_tool = Tool(
|
||||
name="Search",
|
||||
pipeline_or_node=pipeline,
|
||||
description="useful for when you need to Google questions if you cannot find answers in the the previous conversation",
|
||||
output_variable="results",
|
||||
logging_color=Color.MAGENTA,
|
||||
)
|
||||
|
||||
conversational_agent_prompt_node = PromptNode(
|
||||
"gpt-3.5-turbo",
|
||||
api_key=openai_api_key,
|
||||
max_length=256,
|
||||
stop_words=["Observation:"],
|
||||
model_kwargs={"temperature": 0.5, "top_p": 0.9},
|
||||
)
|
||||
memory = ConversationSummaryMemory(conversational_agent_prompt_node, summary_frequency=2)
|
||||
|
||||
conversational_agent = ConversationalAgent(
|
||||
prompt_node=conversational_agent_prompt_node, tools=[web_qa_tool], memory=memory
|
||||
)
|
||||
|
||||
test = False
|
||||
if test:
|
||||
questions = [
|
||||
"Why was Jamie Foxx recently hospitalized?",
|
||||
"Where was he hospitalized?",
|
||||
"What movie was he filming at the time?",
|
||||
"Who is Jamie's female co-star in the movie he was filing at that time?",
|
||||
"Tell me more about her, who is her partner?",
|
||||
]
|
||||
for question in questions:
|
||||
conversational_agent.run(question)
|
||||
else:
|
||||
while True:
|
||||
user_input = input("\nHuman (type 'exit' or 'quit' to quit): ")
|
||||
if user_input.lower() == "exit" or user_input.lower() == "quit":
|
||||
break
|
||||
response = conversational_agent.run(user_input)
|
||||
|
||||
@ -13,7 +13,7 @@ from haystack.agents.memory import Memory, NoMemory
|
||||
from haystack.telemetry import send_event
|
||||
from haystack.agents.agent_step import AgentStep
|
||||
from haystack.agents.types import Color, AgentTokenStreamingHandler
|
||||
from haystack.agents.utils import print_text
|
||||
from haystack.agents.utils import print_text, react_parameter_resolver
|
||||
from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
|
||||
from haystack.pipelines import (
|
||||
BaseStandardPipeline,
|
||||
@ -63,6 +63,7 @@ class Tool:
|
||||
TranslationWrapperPipeline,
|
||||
RetrieverQuestionGenerationPipeline,
|
||||
WebQAPipeline,
|
||||
Callable[[Any], str],
|
||||
],
|
||||
description: str,
|
||||
output_variable: str = "results",
|
||||
@ -85,6 +86,8 @@ class Tool:
|
||||
result = self.pipeline_or_node.run(query=tool_input, params=params)
|
||||
elif isinstance(self.pipeline_or_node, BaseRetriever):
|
||||
result = self.pipeline_or_node.run(query=tool_input, root_node="Query")
|
||||
elif callable(self.pipeline_or_node):
|
||||
result = self.pipeline_or_node(tool_input)
|
||||
else:
|
||||
result = self.pipeline_or_node.run(query=tool_input)
|
||||
return self._process_result(result)
|
||||
@ -266,14 +269,6 @@ class Agent:
|
||||
f"Prompt template '{prompt_template}' not found. Please check the spelling of the template name."
|
||||
)
|
||||
self.prompt_template = resolved_prompt_template
|
||||
react_parameter_resolver: Callable[
|
||||
[str, Agent, AgentStep, Dict[str, Any]], Dict[str, Any]
|
||||
] = lambda query, agent, agent_step, **kwargs: {
|
||||
"query": query,
|
||||
"tool_names": agent.tm.get_tool_names(),
|
||||
"tool_names_with_descriptions": agent.tm.get_tool_names_with_descriptions(),
|
||||
"transcript": agent_step.transcript,
|
||||
}
|
||||
self.prompt_parameters_resolver = (
|
||||
prompt_parameters_resolver if prompt_parameters_resolver else react_parameter_resolver
|
||||
)
|
||||
|
||||
@ -1,63 +1,110 @@
|
||||
from typing import Optional, Callable, Union
|
||||
from typing import Optional, List, Union
|
||||
import logging
|
||||
|
||||
from haystack.agents import Agent
|
||||
from haystack.errors import AgentError
|
||||
from haystack.agents.base import Tool, ToolsManager, Agent
|
||||
from haystack.agents.memory import Memory, ConversationMemory
|
||||
from haystack.nodes import PromptNode, PromptTemplate
|
||||
from haystack.agents.utils import conversational_agent_parameter_resolver, agent_without_tools_parameter_resolver
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConversationalAgent(Agent):
|
||||
"""
|
||||
A conversational agent that can be used to build a conversational chat applications.
|
||||
A ConversationalAgent is an extension of the Agent class that enables the use of tools with several default parameters.
|
||||
ConversationalAgent can manage a set of tools and seamlessly integrate them into the conversation.
|
||||
If no tools are provided, the agent will be initialized to have a basic chat application.
|
||||
|
||||
Here is an example of how you can create a simple chat application:
|
||||
```
|
||||
Here is an example how you can create a chat application with tools:
|
||||
```python
|
||||
import os
|
||||
|
||||
from haystack.agents.base import ConversationalAgent
|
||||
from haystack.agents.memory import ConversationSummaryMemory
|
||||
from haystack.agents.conversational import ConversationalAgent
|
||||
from haystack.nodes import PromptNode
|
||||
from haystack.agents.base import ToolsManager, Tool
|
||||
|
||||
pn = PromptNode("gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"), max_length=256)
|
||||
agent = ConversationalAgent(pn, memory=ConversationSummaryMemory(pn))
|
||||
# Initialize a PromptNode and a ToolsManager with the desired tools
|
||||
prompt_node = PromptNode("gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"), max_length=256)
|
||||
tools = [Tool(name="ExampleTool", pipeline_or_node=example_tool_node)]
|
||||
|
||||
# Create the ConversationalAgent instance
|
||||
agent = ConversationalAgent(prompt_node, tools=tools)
|
||||
|
||||
# Use the agent in a chat application
|
||||
while True:
|
||||
user_input = input("Human (type 'exit' or 'quit' to quit): ")
|
||||
if user_input.lower() == "exit" or user_input.lower() == "quit":
|
||||
break
|
||||
elif user_input.lower() == "memory":
|
||||
print("\nMemory:\n", agent.memory.load())
|
||||
else:
|
||||
assistant_response = agent.run(user_input)
|
||||
print("\nAssistant:", assistant_response)
|
||||
|
||||
```
|
||||
|
||||
If you don't want to have any tools in your chat app, you can create a ConversationalAgent only with a PromptNode:
|
||||
```python
|
||||
import os
|
||||
|
||||
from haystack.agents.conversational import ConversationalAgent
|
||||
from haystack.nodes import PromptNode
|
||||
|
||||
# Initialize a PromptNode
|
||||
prompt_node = PromptNode("gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"), max_length=256)
|
||||
|
||||
# Create the ConversationalAgent instance
|
||||
agent = ConversationalAgent(prompt_node)
|
||||
```
|
||||
|
||||
If you're looking for more customization, check out [Agent](https://docs.haystack.deepset.ai/reference/agent-api).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_node: PromptNode,
|
||||
prompt_template: Optional[Union[str, PromptTemplate]] = None,
|
||||
tools: Optional[List[Tool]] = None,
|
||||
memory: Optional[Memory] = None,
|
||||
prompt_parameters_resolver: Optional[Callable] = None,
|
||||
max_steps: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Creates a new ConversationalAgent instance
|
||||
Creates a new ConversationalAgent instance.
|
||||
|
||||
:param prompt_node: A PromptNode used to communicate with LLM.
|
||||
:param prompt_template: A string or PromptTemplate instance to use as the prompt template. If no prompt_template
|
||||
is provided, the agent will use the "conversational-agent" template.
|
||||
:param memory: A memory instance for storing conversation history and other relevant data, defaults to
|
||||
ConversationMemory.
|
||||
:param prompt_parameters_resolver: An optional callable for resolving prompt template parameters,
|
||||
defaults to a callable that returns a dictionary with the query and the conversation history.
|
||||
:param prompt_node: A PromptNode used by Agent to decide which tool to use and what input to provide to it
|
||||
in each iteration. If there are no tools added, the model specified with PromptNode will be used for chatting.
|
||||
:param prompt_template: A new PromptTemplate or the name of an existing PromptTemplate for the PromptNode. It's
|
||||
used for keeping the chat history, generating thoughts and choosing tools (if provided) to answer queries. It defaults to
|
||||
to "conversational-agent" if there is at least one tool provided and "conversational-agent-without-tools" otherwise.
|
||||
:param tools: A list of tools to use in the Agent. Each tool must have a unique name.
|
||||
:param memory: A memory object for storing conversation history and other relevant data, defaults to
|
||||
ConversationMemory if no memory is provided.
|
||||
:param max_steps: The number of times the Agent can run a tool +1 to let it infer it knows the final answer. It defaults to 5 if there is at least one tool provided and 2 otherwise.
|
||||
"""
|
||||
super().__init__(
|
||||
prompt_node=prompt_node,
|
||||
prompt_template=prompt_template or "conversational-agent",
|
||||
max_steps=2,
|
||||
memory=memory if memory else ConversationMemory(),
|
||||
prompt_parameters_resolver=prompt_parameters_resolver
|
||||
if prompt_parameters_resolver
|
||||
else lambda query, agent, **kwargs: {"query": query, "history": agent.memory.load()},
|
||||
final_answer_pattern=r"^([\s\S]+)$",
|
||||
)
|
||||
|
||||
if tools:
|
||||
super().__init__(
|
||||
prompt_node=prompt_node,
|
||||
memory=memory if memory else ConversationMemory(),
|
||||
tools_manager=ToolsManager(tools=tools),
|
||||
max_steps=max_steps if max_steps else 5,
|
||||
prompt_template=prompt_template if prompt_template else "conversational-agent",
|
||||
final_answer_pattern=r"Final Answer\s*:\s*(.*)",
|
||||
prompt_parameters_resolver=conversational_agent_parameter_resolver,
|
||||
)
|
||||
else:
|
||||
logger.warning("ConversationalAgent is created without tools")
|
||||
|
||||
super().__init__(
|
||||
prompt_node=prompt_node,
|
||||
memory=memory if memory else ConversationMemory(),
|
||||
max_steps=max_steps if max_steps else 2,
|
||||
prompt_template=prompt_template if prompt_template else "conversational-agent-without-tools",
|
||||
final_answer_pattern=r"^([\s\S]+)$",
|
||||
prompt_parameters_resolver=agent_without_tools_parameter_resolver,
|
||||
)
|
||||
|
||||
def add_tool(self, tool: Tool):
|
||||
if len(self.tm.tools) == 0:
|
||||
raise AgentError(
|
||||
"You cannot add tools after initializing the ConversationalAgent without any tools. If you want to add tools, reinitailize the ConversationalAgent and provide `tools`."
|
||||
)
|
||||
return super().add_tool(tool)
|
||||
|
||||
@ -1,6 +1,10 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, TYPE_CHECKING, Dict, Any
|
||||
|
||||
from haystack.agents.types import Color
|
||||
from haystack.agents.agent_step import AgentStep
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from haystack.agents import Agent
|
||||
|
||||
|
||||
def print_text(text: str, end="", color: Optional[Color] = None) -> None:
|
||||
@ -14,3 +18,39 @@ def print_text(text: str, end="", color: Optional[Color] = None) -> None:
|
||||
print(f"{color.value}{text}{Color.RESET.value}", end=end, flush=True)
|
||||
else:
|
||||
print(text, end=end, flush=True)
|
||||
|
||||
|
||||
def react_parameter_resolver(query: str, agent: "Agent", agent_step: AgentStep, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
A parameter resolver for ReAct-based agents that returns the query, the tool names, the tool names
|
||||
with descriptions, and the transcript (internal monologue).
|
||||
"""
|
||||
return {
|
||||
"query": query,
|
||||
"tool_names": agent.tm.get_tool_names(),
|
||||
"tool_names_with_descriptions": agent.tm.get_tool_names_with_descriptions(),
|
||||
"transcript": agent_step.transcript,
|
||||
}
|
||||
|
||||
|
||||
def agent_without_tools_parameter_resolver(query: str, agent: "Agent", **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
A parameter resolver for simple chat agents without tools that returns the query and the history.
|
||||
"""
|
||||
return {"query": query, "history": agent.memory.load()}
|
||||
|
||||
|
||||
def conversational_agent_parameter_resolver(
|
||||
query: str, agent: "Agent", agent_step: AgentStep, **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
A parameter resolver for ReAct-based conversational agent that returns the query, the tool names, the tool names
|
||||
with descriptions, the history of the conversation, and the transcript (internal monologue).
|
||||
"""
|
||||
return {
|
||||
"query": query,
|
||||
"tool_names": agent.tm.get_tool_names(),
|
||||
"tool_names_with_descriptions": agent.tm.get_tool_names_with_descriptions(),
|
||||
"transcript": agent_step.transcript,
|
||||
"history": agent.memory.load(),
|
||||
}
|
||||
|
||||
@ -149,11 +149,29 @@ LEGACY_DEFAULT_TEMPLATES: Dict[str, Dict] = {
|
||||
"Thought: Let's think step-by-step, I first need to {transcript}"
|
||||
},
|
||||
"conversational-agent": {
|
||||
"prompt": "The following is a conversation between a human and an AI.\n{history}\nHuman: {query}\nAI:"
|
||||
"prompt": "In the following conversation, a human user interacts with an AI Agent. The human user poses questions, and the AI Agent goes through several steps to provide well-informed answers.\n"
|
||||
"If the AI Agent knows the answer, the response begins with `Final Answer:` on a new line.\n"
|
||||
"If the AI Agent is uncertain or concerned that the information may be outdated or inaccurate, it must use the available tools to find the most up-to-date information. The AI has access to these tools:\n"
|
||||
"{tool_names_with_descriptions}\n"
|
||||
"The following is the previous conversation between a human and an AI:\n"
|
||||
"{history}\n"
|
||||
"AI Agent responses must start with one of the following:\n"
|
||||
"Thought: [AI Agent's reasoning process]\n"
|
||||
"Tool: [{tool_names}] (on a new line) Tool Input: [input for the selected tool WITHOUT quotation marks and on a new line] (These must always be provided together and on separate lines.)\n"
|
||||
"Final Answer: [final answer to the human user's question]\n"
|
||||
"When selecting a tool, the AI Agent must provide both the `Tool:` and `Tool Input:` pair in the same response, but on separate lines. `Observation:` marks the beginning of a tool's result, and the AI Agent trusts these results.\n"
|
||||
"The AI Agent should not ask the human user for additional information, clarification, or context.\n"
|
||||
"If the AI Agent cannot find a specific answer after exhausting available tools and approaches, it answers with Final Answer: inconclusive\n"
|
||||
"Question: {query}\n"
|
||||
"Thought:\n"
|
||||
"{transcript}\n"
|
||||
},
|
||||
"conversational-summary": {
|
||||
"prompt": "Condense the following chat transcript by shortening and summarizing the content without losing important information:\n{chat_transcript}\nCondensed Transcript:"
|
||||
},
|
||||
"conversational-agent-without-tools": {
|
||||
"prompt": "The following is a conversation between a human and an AI.\n{history}\nHuman: {query}\nAI:"
|
||||
},
|
||||
# DO NOT ADD ANY NEW TEMPLATE IN HERE!
|
||||
}
|
||||
|
||||
|
||||
@ -1,14 +1,22 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, Mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from haystack.errors import AgentError
|
||||
from haystack.agents.base import Tool
|
||||
from haystack.agents.conversational import ConversationalAgent
|
||||
from haystack.agents.memory import ConversationSummaryMemory, ConversationMemory, NoMemory
|
||||
from test.conftest import MockPromptNode
|
||||
from haystack.nodes import PromptNode
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
|
||||
def prompt_node(mock_model):
|
||||
prompt_node = PromptNode()
|
||||
return prompt_node
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init():
|
||||
prompt_node = MockPromptNode()
|
||||
def test_init_without_tools(prompt_node):
|
||||
agent = ConversationalAgent(prompt_node)
|
||||
|
||||
# Test normal case
|
||||
@ -16,34 +24,82 @@ def test_init():
|
||||
assert callable(agent.prompt_parameters_resolver)
|
||||
assert agent.max_steps == 2
|
||||
assert agent.final_answer_pattern == r"^([\s\S]+)$"
|
||||
assert agent.prompt_template.name == "conversational-agent-without-tools"
|
||||
|
||||
# ConversationalAgent doesn't have tools
|
||||
assert not agent.tm.tools
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_with_summary_memory():
|
||||
def test_init_with_tools(prompt_node):
|
||||
agent = ConversationalAgent(prompt_node, tools=[Tool("ExampleTool", lambda x: x, description="Example tool")])
|
||||
|
||||
# Test normal case
|
||||
assert isinstance(agent.memory, ConversationMemory)
|
||||
assert callable(agent.prompt_parameters_resolver)
|
||||
assert agent.max_steps == 5
|
||||
assert agent.final_answer_pattern == r"Final Answer\s*:\s*(.*)"
|
||||
assert agent.prompt_template.name == "conversational-agent"
|
||||
assert agent.has_tool("ExampleTool")
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_with_summary_memory(prompt_node):
|
||||
# Test with summary memory
|
||||
prompt_node = MockPromptNode()
|
||||
agent = ConversationalAgent(prompt_node, memory=ConversationSummaryMemory(prompt_node))
|
||||
assert isinstance(agent.memory, ConversationSummaryMemory)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_with_no_memory():
|
||||
prompt_node = MockPromptNode()
|
||||
def test_init_with_no_memory(prompt_node):
|
||||
# Test with no memory
|
||||
agent = ConversationalAgent(prompt_node, memory=NoMemory())
|
||||
assert isinstance(agent.memory, NoMemory)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_run():
|
||||
prompt_node = MockPromptNode()
|
||||
def test_init_with_custom_max_steps(prompt_node):
|
||||
# Test with custom max step
|
||||
agent = ConversationalAgent(prompt_node, max_steps=8)
|
||||
assert agent.max_steps == 8
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_init_with_custom_prompt_template(prompt_node):
|
||||
# Test with custom prompt template
|
||||
agent = ConversationalAgent(prompt_node, prompt_template="translation")
|
||||
assert agent.prompt_template.name == "translation"
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_run(prompt_node):
|
||||
agent = ConversationalAgent(prompt_node)
|
||||
|
||||
# Mock the Agent run method
|
||||
result = agent.run("query")
|
||||
agent.run = MagicMock(return_value="Hello")
|
||||
assert agent.run("query") == "Hello"
|
||||
agent.run.assert_called_once_with("query")
|
||||
|
||||
# empty answer
|
||||
assert result["answers"][0].answer == ""
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_tool(prompt_node):
|
||||
agent = ConversationalAgent(prompt_node, tools=[Tool("ExampleTool", lambda x: x, description="Example tool")])
|
||||
# ConversationalAgent has tools
|
||||
assert len(agent.tm.tools) == 1
|
||||
|
||||
# and add more tools if ConversationalAgent is initialized with tools
|
||||
agent.add_tool(Tool("AnotherTool", lambda x: x, description="Example tool"))
|
||||
assert len(agent.tm.tools) == 2
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_tool_not_allowed(prompt_node):
|
||||
agent = ConversationalAgent(prompt_node)
|
||||
# ConversationalAgent has no tools
|
||||
assert not agent.tm.tools
|
||||
|
||||
# and can't add tools when a ConversationalAgent is initialized without tools
|
||||
with pytest.raises(
|
||||
AgentError, match="You cannot add tools after initializing the ConversationalAgent without any tools."
|
||||
):
|
||||
agent.add_tool(Tool("ExampleTool", lambda x: x, description="Example tool"))
|
||||
|
||||
@ -17,6 +17,14 @@ def tools_manager():
|
||||
return ToolsManager(tools=tools)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_using_callable_as_tool():
|
||||
# test that we can also pass a callable as a tool
|
||||
tool_input = "Haystack"
|
||||
tool = Tool(name="ToolA", pipeline_or_node=lambda x: x + x, description="Tool A Description")
|
||||
assert tool.run(tool_input) == tool_input + tool_input
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
def test_add_tool(tools_manager):
|
||||
new_tool = Tool(name="ToolC", pipeline_or_node=mock.Mock(), description="Tool C Description")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user