feat: Update ConversationalAgent (#5065)

* feat: Update ConversationalAgent

* Add Tools
* Add test
* Change default params

* fix tests

* Fix circular import error
* Update conversational-agent prompt
* Add conversational-agent-without-tools to legacy list

* Add warning to add tools to conversational agent

* Add callable tools

* Add example script

* Fix linter errors

* Update ConversationalAgent depending on the existence of tools

* Initialize the base Agent with different arguments when there's tool
* Inject memory to the prompt in both cases, update prompts accordingly

* Override the add_tools method to prevent adding tools to ConversationalAgent without tools

* Update test

* Fix linter error

* Remove unused import

* Update docstrings and api reference

* Fix imports and doc string code snippet

* docstrings update

* Update conversational.py

* Mock PromptNode

* Prevent circular import error

* Add max_steps to the ConversationalAgent

* Update resolver description

* Add prompt_template as parameter

* Change docstring

---------

Co-authored-by: Darja Fokina <daria.f93@gmail.com>
This commit is contained in:
Bilge Yücel 2023-06-20 13:09:21 +03:00 committed by GitHub
parent 30fdf2b5df
commit 6a1b6b1ae3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 281 additions and 68 deletions

View File

@ -1,7 +1,7 @@
loaders:
- type: python
search_path: [../../../haystack/agents]
modules: ['base']
modules: ['base', 'conversational']
ignore_when_discovered: ['__init__']
processors:
- type: filter

View File

@ -1,17 +1,66 @@
import os
from haystack.agents.base import Tool
from haystack.agents.conversational import ConversationalAgent
from haystack.nodes import PromptNode
from haystack.agents.memory import ConversationSummaryMemory
from haystack.nodes import PromptNode, WebRetriever, PromptTemplate
from haystack.pipelines import WebQAPipeline
from haystack.agents.types import Color
pn = PromptNode("gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"), max_length=256)
agent = ConversationalAgent(pn)
search_api_key = os.environ.get("SEARCH_API_KEY")
if not search_api_key:
raise ValueError("Please set the SEARCH_API_KEY environment variable")
openai_api_key = os.environ.get("OPENAI_API_KEY")
if not openai_api_key:
raise ValueError("Please set the OPENAI_API_KEY environment variable")
while True:
user_input = input("Human (type 'exit' or 'quit' to quit, 'memory' for agent's memory): ")
if user_input.lower() == "exit" or user_input.lower() == "quit":
break
if user_input.lower() == "memory":
print("\nMemory:\n", agent.memory.load())
else:
assistant_response = agent.run(user_input)
print("\nAssistant:", assistant_response)
web_prompt = """
Synthesize a comprehensive answer from the following most relevant paragraphs and the given question.
Provide a clear and concise answer, no longer than 10-20 words.
\n\n Paragraphs: {documents} \n\n Question: {query} \n\n Answer:
"""
web_prompt_node = PromptNode(
"gpt-3.5-turbo", default_prompt_template=PromptTemplate(prompt=web_prompt), api_key=openai_api_key
)
web_retriever = WebRetriever(api_key=search_api_key, top_search_results=3, mode="snippets")
pipeline = WebQAPipeline(retriever=web_retriever, prompt_node=web_prompt_node)
web_qa_tool = Tool(
name="Search",
pipeline_or_node=pipeline,
description="useful for when you need to Google questions if you cannot find answers in the the previous conversation",
output_variable="results",
logging_color=Color.MAGENTA,
)
conversational_agent_prompt_node = PromptNode(
"gpt-3.5-turbo",
api_key=openai_api_key,
max_length=256,
stop_words=["Observation:"],
model_kwargs={"temperature": 0.5, "top_p": 0.9},
)
memory = ConversationSummaryMemory(conversational_agent_prompt_node, summary_frequency=2)
conversational_agent = ConversationalAgent(
prompt_node=conversational_agent_prompt_node, tools=[web_qa_tool], memory=memory
)
test = False
if test:
questions = [
"Why was Jamie Foxx recently hospitalized?",
"Where was he hospitalized?",
"What movie was he filming at the time?",
"Who is Jamie's female co-star in the movie he was filing at that time?",
"Tell me more about her, who is her partner?",
]
for question in questions:
conversational_agent.run(question)
else:
while True:
user_input = input("\nHuman (type 'exit' or 'quit' to quit): ")
if user_input.lower() == "exit" or user_input.lower() == "quit":
break
response = conversational_agent.run(user_input)

View File

@ -13,7 +13,7 @@ from haystack.agents.memory import Memory, NoMemory
from haystack.telemetry import send_event
from haystack.agents.agent_step import AgentStep
from haystack.agents.types import Color, AgentTokenStreamingHandler
from haystack.agents.utils import print_text
from haystack.agents.utils import print_text, react_parameter_resolver
from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
from haystack.pipelines import (
BaseStandardPipeline,
@ -63,6 +63,7 @@ class Tool:
TranslationWrapperPipeline,
RetrieverQuestionGenerationPipeline,
WebQAPipeline,
Callable[[Any], str],
],
description: str,
output_variable: str = "results",
@ -85,6 +86,8 @@ class Tool:
result = self.pipeline_or_node.run(query=tool_input, params=params)
elif isinstance(self.pipeline_or_node, BaseRetriever):
result = self.pipeline_or_node.run(query=tool_input, root_node="Query")
elif callable(self.pipeline_or_node):
result = self.pipeline_or_node(tool_input)
else:
result = self.pipeline_or_node.run(query=tool_input)
return self._process_result(result)
@ -266,14 +269,6 @@ class Agent:
f"Prompt template '{prompt_template}' not found. Please check the spelling of the template name."
)
self.prompt_template = resolved_prompt_template
react_parameter_resolver: Callable[
[str, Agent, AgentStep, Dict[str, Any]], Dict[str, Any]
] = lambda query, agent, agent_step, **kwargs: {
"query": query,
"tool_names": agent.tm.get_tool_names(),
"tool_names_with_descriptions": agent.tm.get_tool_names_with_descriptions(),
"transcript": agent_step.transcript,
}
self.prompt_parameters_resolver = (
prompt_parameters_resolver if prompt_parameters_resolver else react_parameter_resolver
)

View File

@ -1,63 +1,110 @@
from typing import Optional, Callable, Union
from typing import Optional, List, Union
import logging
from haystack.agents import Agent
from haystack.errors import AgentError
from haystack.agents.base import Tool, ToolsManager, Agent
from haystack.agents.memory import Memory, ConversationMemory
from haystack.nodes import PromptNode, PromptTemplate
from haystack.agents.utils import conversational_agent_parameter_resolver, agent_without_tools_parameter_resolver
logger = logging.getLogger(__name__)
class ConversationalAgent(Agent):
"""
A conversational agent that can be used to build a conversational chat applications.
A ConversationalAgent is an extension of the Agent class that enables the use of tools with several default parameters.
ConversationalAgent can manage a set of tools and seamlessly integrate them into the conversation.
If no tools are provided, the agent will be initialized to have a basic chat application.
Here is an example of how you can create a simple chat application:
```
Here is an example how you can create a chat application with tools:
```python
import os
from haystack.agents.base import ConversationalAgent
from haystack.agents.memory import ConversationSummaryMemory
from haystack.agents.conversational import ConversationalAgent
from haystack.nodes import PromptNode
from haystack.agents.base import ToolsManager, Tool
pn = PromptNode("gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"), max_length=256)
agent = ConversationalAgent(pn, memory=ConversationSummaryMemory(pn))
# Initialize a PromptNode and a ToolsManager with the desired tools
prompt_node = PromptNode("gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"), max_length=256)
tools = [Tool(name="ExampleTool", pipeline_or_node=example_tool_node)]
# Create the ConversationalAgent instance
agent = ConversationalAgent(prompt_node, tools=tools)
# Use the agent in a chat application
while True:
user_input = input("Human (type 'exit' or 'quit' to quit): ")
if user_input.lower() == "exit" or user_input.lower() == "quit":
break
elif user_input.lower() == "memory":
print("\nMemory:\n", agent.memory.load())
else:
assistant_response = agent.run(user_input)
print("\nAssistant:", assistant_response)
```
If you don't want to have any tools in your chat app, you can create a ConversationalAgent only with a PromptNode:
```python
import os
from haystack.agents.conversational import ConversationalAgent
from haystack.nodes import PromptNode
# Initialize a PromptNode
prompt_node = PromptNode("gpt-3.5-turbo", api_key=os.environ.get("OPENAI_API_KEY"), max_length=256)
# Create the ConversationalAgent instance
agent = ConversationalAgent(prompt_node)
```
If you're looking for more customization, check out [Agent](https://docs.haystack.deepset.ai/reference/agent-api).
"""
def __init__(
self,
prompt_node: PromptNode,
prompt_template: Optional[Union[str, PromptTemplate]] = None,
tools: Optional[List[Tool]] = None,
memory: Optional[Memory] = None,
prompt_parameters_resolver: Optional[Callable] = None,
max_steps: Optional[int] = None,
):
"""
Creates a new ConversationalAgent instance
Creates a new ConversationalAgent instance.
:param prompt_node: A PromptNode used to communicate with LLM.
:param prompt_template: A string or PromptTemplate instance to use as the prompt template. If no prompt_template
is provided, the agent will use the "conversational-agent" template.
:param memory: A memory instance for storing conversation history and other relevant data, defaults to
ConversationMemory.
:param prompt_parameters_resolver: An optional callable for resolving prompt template parameters,
defaults to a callable that returns a dictionary with the query and the conversation history.
:param prompt_node: A PromptNode used by Agent to decide which tool to use and what input to provide to it
in each iteration. If there are no tools added, the model specified with PromptNode will be used for chatting.
:param prompt_template: A new PromptTemplate or the name of an existing PromptTemplate for the PromptNode. It's
used for keeping the chat history, generating thoughts and choosing tools (if provided) to answer queries. It defaults to
to "conversational-agent" if there is at least one tool provided and "conversational-agent-without-tools" otherwise.
:param tools: A list of tools to use in the Agent. Each tool must have a unique name.
:param memory: A memory object for storing conversation history and other relevant data, defaults to
ConversationMemory if no memory is provided.
:param max_steps: The number of times the Agent can run a tool +1 to let it infer it knows the final answer. It defaults to 5 if there is at least one tool provided and 2 otherwise.
"""
super().__init__(
prompt_node=prompt_node,
prompt_template=prompt_template or "conversational-agent",
max_steps=2,
memory=memory if memory else ConversationMemory(),
prompt_parameters_resolver=prompt_parameters_resolver
if prompt_parameters_resolver
else lambda query, agent, **kwargs: {"query": query, "history": agent.memory.load()},
final_answer_pattern=r"^([\s\S]+)$",
)
if tools:
super().__init__(
prompt_node=prompt_node,
memory=memory if memory else ConversationMemory(),
tools_manager=ToolsManager(tools=tools),
max_steps=max_steps if max_steps else 5,
prompt_template=prompt_template if prompt_template else "conversational-agent",
final_answer_pattern=r"Final Answer\s*:\s*(.*)",
prompt_parameters_resolver=conversational_agent_parameter_resolver,
)
else:
logger.warning("ConversationalAgent is created without tools")
super().__init__(
prompt_node=prompt_node,
memory=memory if memory else ConversationMemory(),
max_steps=max_steps if max_steps else 2,
prompt_template=prompt_template if prompt_template else "conversational-agent-without-tools",
final_answer_pattern=r"^([\s\S]+)$",
prompt_parameters_resolver=agent_without_tools_parameter_resolver,
)
def add_tool(self, tool: Tool):
if len(self.tm.tools) == 0:
raise AgentError(
"You cannot add tools after initializing the ConversationalAgent without any tools. If you want to add tools, reinitailize the ConversationalAgent and provide `tools`."
)
return super().add_tool(tool)

View File

@ -1,6 +1,10 @@
from typing import Optional
from typing import Optional, TYPE_CHECKING, Dict, Any
from haystack.agents.types import Color
from haystack.agents.agent_step import AgentStep
if TYPE_CHECKING:
from haystack.agents import Agent
def print_text(text: str, end="", color: Optional[Color] = None) -> None:
@ -14,3 +18,39 @@ def print_text(text: str, end="", color: Optional[Color] = None) -> None:
print(f"{color.value}{text}{Color.RESET.value}", end=end, flush=True)
else:
print(text, end=end, flush=True)
def react_parameter_resolver(query: str, agent: "Agent", agent_step: AgentStep, **kwargs) -> Dict[str, Any]:
    """
    Resolve the prompt parameters needed by ReAct-based agents.

    Returns a dictionary containing the query, the names of the available
    tools, the tool names with their descriptions, and the transcript
    (the agent's internal monologue) from the current agent step.
    """
    manager = agent.tm
    params = dict(
        query=query,
        transcript=agent_step.transcript,
        tool_names=manager.get_tool_names(),
        tool_names_with_descriptions=manager.get_tool_names_with_descriptions(),
    )
    return params
def agent_without_tools_parameter_resolver(query: str, agent: "Agent", **kwargs) -> Dict[str, Any]:
    """
    Resolve the prompt parameters for a plain chat agent that has no tools.

    Returns a dictionary with only the query and the conversation history
    loaded from the agent's memory.
    """
    history = agent.memory.load()
    return {"query": query, "history": history}
def conversational_agent_parameter_resolver(
    query: str, agent: "Agent", agent_step: AgentStep, **kwargs
) -> Dict[str, Any]:
    """
    Resolve the prompt parameters for a ReAct-based conversational agent.

    Returns a dictionary with the query, the tool names, the tool names with
    their descriptions, the conversation history loaded from the agent's
    memory, and the transcript (internal monologue) of the current agent step.
    """
    manager = agent.tm
    return dict(
        query=query,
        tool_names=manager.get_tool_names(),
        tool_names_with_descriptions=manager.get_tool_names_with_descriptions(),
        transcript=agent_step.transcript,
        history=agent.memory.load(),
    )

View File

@ -149,11 +149,29 @@ LEGACY_DEFAULT_TEMPLATES: Dict[str, Dict] = {
"Thought: Let's think step-by-step, I first need to {transcript}"
},
"conversational-agent": {
"prompt": "The following is a conversation between a human and an AI.\n{history}\nHuman: {query}\nAI:"
"prompt": "In the following conversation, a human user interacts with an AI Agent. The human user poses questions, and the AI Agent goes through several steps to provide well-informed answers.\n"
"If the AI Agent knows the answer, the response begins with `Final Answer:` on a new line.\n"
"If the AI Agent is uncertain or concerned that the information may be outdated or inaccurate, it must use the available tools to find the most up-to-date information. The AI has access to these tools:\n"
"{tool_names_with_descriptions}\n"
"The following is the previous conversation between a human and an AI:\n"
"{history}\n"
"AI Agent responses must start with one of the following:\n"
"Thought: [AI Agent's reasoning process]\n"
"Tool: [{tool_names}] (on a new line) Tool Input: [input for the selected tool WITHOUT quotation marks and on a new line] (These must always be provided together and on separate lines.)\n"
"Final Answer: [final answer to the human user's question]\n"
"When selecting a tool, the AI Agent must provide both the `Tool:` and `Tool Input:` pair in the same response, but on separate lines. `Observation:` marks the beginning of a tool's result, and the AI Agent trusts these results.\n"
"The AI Agent should not ask the human user for additional information, clarification, or context.\n"
"If the AI Agent cannot find a specific answer after exhausting available tools and approaches, it answers with Final Answer: inconclusive\n"
"Question: {query}\n"
"Thought:\n"
"{transcript}\n"
},
"conversational-summary": {
"prompt": "Condense the following chat transcript by shortening and summarizing the content without losing important information:\n{chat_transcript}\nCondensed Transcript:"
},
"conversational-agent-without-tools": {
"prompt": "The following is a conversation between a human and an AI.\n{history}\nHuman: {query}\nAI:"
},
# DO NOT ADD ANY NEW TEMPLATE IN HERE!
}

View File

@ -1,14 +1,22 @@
import pytest
from unittest.mock import MagicMock, Mock
from unittest.mock import MagicMock, patch
from haystack.errors import AgentError
from haystack.agents.base import Tool
from haystack.agents.conversational import ConversationalAgent
from haystack.agents.memory import ConversationSummaryMemory, ConversationMemory, NoMemory
from test.conftest import MockPromptNode
from haystack.nodes import PromptNode
@pytest.fixture
@patch("haystack.nodes.prompt.prompt_node.PromptModel")
def prompt_node(mock_model):
    # PromptModel is patched so constructing the node loads no real model.
    return PromptNode()
@pytest.mark.unit
def test_init():
prompt_node = MockPromptNode()
def test_init_without_tools(prompt_node):
agent = ConversationalAgent(prompt_node)
# Test normal case
@ -16,34 +24,82 @@ def test_init():
assert callable(agent.prompt_parameters_resolver)
assert agent.max_steps == 2
assert agent.final_answer_pattern == r"^([\s\S]+)$"
assert agent.prompt_template.name == "conversational-agent-without-tools"
# ConversationalAgent doesn't have tools
assert not agent.tm.tools
@pytest.mark.unit
def test_init_with_summary_memory():
def test_init_with_tools(prompt_node):
agent = ConversationalAgent(prompt_node, tools=[Tool("ExampleTool", lambda x: x, description="Example tool")])
# Test normal case
assert isinstance(agent.memory, ConversationMemory)
assert callable(agent.prompt_parameters_resolver)
assert agent.max_steps == 5
assert agent.final_answer_pattern == r"Final Answer\s*:\s*(.*)"
assert agent.prompt_template.name == "conversational-agent"
assert agent.has_tool("ExampleTool")
@pytest.mark.unit
def test_init_with_summary_memory(prompt_node):
# Test with summary memory
prompt_node = MockPromptNode()
agent = ConversationalAgent(prompt_node, memory=ConversationSummaryMemory(prompt_node))
assert isinstance(agent.memory, ConversationSummaryMemory)
@pytest.mark.unit
def test_init_with_no_memory():
prompt_node = MockPromptNode()
def test_init_with_no_memory(prompt_node):
# Test with no memory
agent = ConversationalAgent(prompt_node, memory=NoMemory())
assert isinstance(agent.memory, NoMemory)
@pytest.mark.unit
def test_run():
prompt_node = MockPromptNode()
def test_init_with_custom_max_steps(prompt_node):
# Test with custom max step
agent = ConversationalAgent(prompt_node, max_steps=8)
assert agent.max_steps == 8
@pytest.mark.unit
def test_init_with_custom_prompt_template(prompt_node):
# Test with custom prompt template
agent = ConversationalAgent(prompt_node, prompt_template="translation")
assert agent.prompt_template.name == "translation"
@pytest.mark.unit
def test_run(prompt_node):
agent = ConversationalAgent(prompt_node)
# Mock the Agent run method
result = agent.run("query")
agent.run = MagicMock(return_value="Hello")
assert agent.run("query") == "Hello"
agent.run.assert_called_once_with("query")
# empty answer
assert result["answers"][0].answer == ""
@pytest.mark.unit
def test_add_tool(prompt_node):
agent = ConversationalAgent(prompt_node, tools=[Tool("ExampleTool", lambda x: x, description="Example tool")])
# ConversationalAgent has tools
assert len(agent.tm.tools) == 1
# and add more tools if ConversationalAgent is initialized with tools
agent.add_tool(Tool("AnotherTool", lambda x: x, description="Example tool"))
assert len(agent.tm.tools) == 2
@pytest.mark.unit
def test_add_tool_not_allowed(prompt_node):
agent = ConversationalAgent(prompt_node)
# ConversationalAgent has no tools
assert not agent.tm.tools
# and can't add tools when a ConversationalAgent is initialized without tools
with pytest.raises(
AgentError, match="You cannot add tools after initializing the ConversationalAgent without any tools."
):
agent.add_tool(Tool("ExampleTool", lambda x: x, description="Example tool"))

View File

@ -17,6 +17,14 @@ def tools_manager():
return ToolsManager(tools=tools)
@pytest.mark.unit
def test_using_callable_as_tool():
    # A bare Python callable is a valid pipeline_or_node for a Tool.
    query = "Haystack"
    doubling_tool = Tool(name="ToolA", pipeline_or_node=lambda x: x + x, description="Tool A Description")
    assert doubling_tool.run(query) == query + query
@pytest.mark.unit
def test_add_tool(tools_manager):
new_tool = Tool(name="ToolC", pipeline_or_node=mock.Mock(), description="Tool C Description")