feat: Add agent event callbacks (#4491)

* Implement agent callbacks with events

* Fix mypy errors

* Fix prompt_params assignment

* PR review fixes

---------

Co-authored-by: Silvano Cerza <silvanocerza@gmail.com>
Author: Vladimir Blagojevic (committed by GitHub)
Date: 2023-03-27 10:06:11 +02:00
Commit: c99b58100d (parent: 2a2226d63e)
5 changed files with 112 additions and 15 deletions


haystack/agents/base.py

@@ -4,10 +4,15 @@ import logging
 import re
 from typing import List, Optional, Union, Dict, Any
 
+from events import Events
+
 from haystack import Pipeline, BaseComponent, Answer, Document
 from haystack.agents.agent_step import AgentStep
+from haystack.agents.types import Color
+from haystack.agents.utils import print_text
 from haystack.errors import AgentError
 from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
+from haystack.nodes.prompt.providers import TokenStreamingHandler
 from haystack.pipelines import (
     BaseStandardPipeline,
     ExtractiveQAPipeline,
@@ -57,7 +62,8 @@ class Tool:
             RetrieverQuestionGenerationPipeline,
         ],
         description: str,
-        output_variable: Optional[str] = "results",
+        output_variable: str = "results",
+        logging_color: Color = Color.YELLOW,
     ):
         if re.search(r"\W", name):
             raise ValueError(
@@ -68,6 +74,7 @@ class Tool:
         self.pipeline_or_node = pipeline_or_node
         self.description = description
         self.output_variable = output_variable
+        self.logging_color = logging_color
 
     def run(self, tool_input: str, params: Optional[dict] = None) -> str:
         # We can only pass params to pipelines but not to nodes
@@ -144,6 +151,17 @@ class Agent:
         text the Agent generated.
         :param final_answer_pattern: A regular expression to extract the final answer from the text the Agent generated.
         """
+        self.callback_manager = Events(
+            (
+                "on_tool_start",
+                "on_tool_finish",
+                "on_tool_error",
+                "on_agent_start",
+                "on_agent_step",
+                "on_agent_finish",
+                "on_new_token",
+            )
+        )
         self.prompt_node = prompt_node
         self.prompt_template = (
             prompt_node.get_prompt_template(prompt_template) if isinstance(prompt_template, str) else prompt_template
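
For context, the `events` package added here gives each named event a subscribable hook. A minimal standalone sketch of that mechanism (the names below are illustrative, not part of this commit):

    from events import Events

    # Restrict the manager to a fixed set of event names, as the Agent does.
    callbacks = Events(("on_new_token", "on_agent_finish"))

    # Subscribing is plain +=; several handlers can attach to one event.
    callbacks.on_new_token += lambda token, **kwargs: print(token, end="")

    # Firing an event calls every subscribed handler in order.
    callbacks.on_new_token("Hello")
    # Events with no subscribers are silently ignored.
    callbacks.on_agent_finish("final step")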
@@ -156,8 +174,31 @@ class Agent:
         self.max_steps = max_steps
         self.tool_pattern = tool_pattern
         self.final_answer_pattern = final_answer_pattern
+        self.add_default_logging_callbacks()
         send_custom_event(event=f"{type(self).__name__} initialized")
 
+    def add_default_logging_callbacks(self, agent_color: Color = Color.GREEN) -> None:
+        def on_tool_finish(
+            tool_output: str,
+            color: Optional[Color] = None,
+            observation_prefix: Optional[str] = None,
+            llm_prefix: Optional[str] = None,
+            **kwargs: Any,
+        ) -> None:
+            print_text(observation_prefix)  # type: ignore
+            print_text(tool_output, color=color)
+            print_text(f"\n{llm_prefix}")
+
+        def on_agent_start(**kwargs: Any) -> None:
+            agent_name = kwargs.pop("name", "react")
+            print_text(f"\nAgent {agent_name} started with {kwargs}\n")
+
+        self.callback_manager.on_tool_finish += on_tool_finish
+        self.callback_manager.on_agent_start += on_agent_start
+        self.callback_manager.on_new_token += lambda token, **kwargs: print_text(token, color=agent_color)
+
     def add_tool(self, tool: Tool):
         """
         Add a tool to the Agent. This also updates the PromptTemplate for the Agent's PromptNode with the tool name.
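
Handlers beyond these defaults can be attached the same way. A hypothetical example, assuming `agent` is an initialized Agent (the `log_step` helper is illustrative, not part of this commit):

    def log_step(agent_step, **kwargs):
        # Append every intermediate step to a local log file.
        with open("agent_steps.log", "a") as log_file:
            log_file.write(f"{agent_step}\n")

    agent.callback_manager.on_agent_step += log_step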
@@ -215,11 +256,13 @@ class Agent:
                 f"max_steps must be at least 2 to let the Agent use a tool once and then infer it knows the final "
                 f"answer. It was set to {max_steps}."
             )
+        self.callback_manager.on_agent_start(name=self.prompt_template.name, query=query, params=params)
         agent_step = self._create_first_step(query, max_steps)
-        while not agent_step.is_last():
-            agent_step = self._step(agent_step, params)
+        try:
+            while not agent_step.is_last():
+                agent_step = self._step(agent_step, params)
+        finally:
+            self.callback_manager.on_agent_finish(agent_step)
         return agent_step.final_answer(query=query)
 
     def _create_first_step(self, query: str, max_steps: int = 10):
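
The try/finally wrapper guarantees that on_agent_finish fires even when a step raises. A standalone sketch of that pattern (not commit code):

    def run_steps(steps):
        try:
            for step in steps:
                step()
        finally:
            print("on_agent_finish fired")  # runs on success and on error alike

    def failing_step():
        raise ValueError("tool failed")

    try:
        run_steps([failing_step])
    except ValueError:
        pass  # the finally block has already fired the finish event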
@@ -233,11 +276,21 @@ class Agent:
         )
 
     def _step(self, current_step: AgentStep, params: Optional[dict] = None):
+        cm = self.callback_manager
+
+        class AgentTokenStreamingHandler(TokenStreamingHandler):
+            def __call__(self, token_received, **kwargs) -> str:
+                cm.on_new_token(token_received, **kwargs)
+                return token_received
+
         # plan next step using the LLM
-        prompt_node_response = self.prompt_node(current_step.prepare_prompt())
+        prompt_node_response = self.prompt_node(
+            current_step.prepare_prompt(), stream_handler=AgentTokenStreamingHandler()
+        )
 
         # from the LLM response, create the next step
         next_step = current_step.create_next_step(prompt_node_response)
+        self.callback_manager.on_agent_step(next_step)
 
         # run the tool selected by the LLM
         observation = self._run_tool(next_step, params) if not next_step.is_last() else None
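
With `stream_handler` passed to the PromptNode, every generated token now flows through `on_new_token`. Any object implementing the same `TokenStreamingHandler` interface could be swapped in; for instance, a sketch of a handler that buffers tokens instead of printing them (illustrative, not part of this commit):

    from haystack.nodes.prompt.providers import TokenStreamingHandler

    class CollectingHandler(TokenStreamingHandler):
        def __init__(self):
            self.tokens = []

        def __call__(self, token_received: str, **kwargs) -> str:
            # Buffer the token; return it so the generated text stays intact.
            self.tokens.append(token_received)
            return token_received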
@@ -286,7 +339,18 @@ class Agent:
                 "Add the tool using `add_tool()` or include it in the parameter `tools` when initializing the Agent."
                 f"Agent Step::\n{next_step}"
             )
-        return self.tools[tool_name].run(tool_input, params)
+        tool_result: str = ""
+        tool: Tool = self.tools[tool_name]
+        try:
+            self.callback_manager.on_tool_start(tool_input, tool=tool)
+            tool_result = tool.run(tool_input, params)
+            self.callback_manager.on_tool_finish(
+                tool_result, observation_prefix="Observation: ", llm_prefix="Thought: ", color=tool.logging_color
+            )
+        except Exception as e:
+            self.callback_manager.on_tool_error(e, tool=self.tools[tool_name])
+            raise e
+        return tool_result
 
     def _get_initial_transcript(self, query: str):
         """

haystack/agents/types.py (new file, 13 lines)

@@ -0,0 +1,13 @@
+from enum import Enum
+
+
+class Color(Enum):
+    BLACK = "\033[30m"
+    RED = "\033[31m"
+    GREEN = "\033[32m"
+    YELLOW = "\033[33m"
+    BLUE = "\033[34m"
+    MAGENTA = "\033[35m"
+    CYAN = "\033[36m"
+    WHITE = "\033[37m"
+    RESET = "\x1b[0m"
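
A quick check of the escape codes above, assuming a terminal with ANSI support (the sample text is illustrative):

    from haystack.agents.types import Color

    # Wrap text in a color code, then restore the terminal default.
    print(f"{Color.YELLOW.value}Observation:{Color.RESET.value} 42")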

haystack/agents/utils.py (new file, 16 lines)

@@ -0,0 +1,16 @@
+from typing import Optional
+
+from haystack.agents.types import Color
+
+
+def print_text(text: str, end="", color: Optional[Color] = None) -> None:
+    """
+    Print text with optional color.
+    :param text: Text to print.
+    :param end: End character to use (defaults to "").
+    :param color: Color to print text in (defaults to None).
+    """
+    if color:
+        print(f"{color.value}{text}{Color.RESET.value}", end=end, flush=True)
+    else:
+        print(text, end=end, flush=True)
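
Example usage of the helper above (the printed strings are illustrative):

    from haystack.agents.types import Color
    from haystack.agents.utils import print_text

    print_text("Thought: ")  # plain text, no trailing newline
    print_text("I should use the search tool\n", color=Color.GREEN)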


haystack/nodes/prompt/prompt_node.py

@@ -71,15 +71,17 @@ class PromptTemplate(BasePromptTemplate, ABC):
         :param prompt_params: Optional parameters that need to be filled in the prompt text. If you don't specify them, they're inferred from the prompt text. Any variable in prompt text in the format `$variablename` is interpreted as a prompt parameter.
         """
         super().__init__()
-        if not prompt_params:
+        if prompt_params:
+            self.prompt_params = prompt_params
+        else:
             # Define the regex pattern to match the strings after the $ character
             pattern = r"\$([a-zA-Z0-9_]+)"
-            prompt_params = re.findall(pattern, prompt_text)
-        if prompt_text.count("$") != len(prompt_params):
+            self.prompt_params = re.findall(pattern, prompt_text)
+        if prompt_text.count("$") != len(self.prompt_params):
             raise ValueError(
                 f"The number of parameters in prompt text {prompt_text} for prompt template {name} "
-                f"does not match the number of specified parameters {prompt_params}."
+                f"does not match the number of specified parameters {self.prompt_params}."
             )
 
         # use case when PromptTemplate is loaded from a YAML file, we need to start and end the prompt text with quotes

@@ -87,16 +89,15 @@ class PromptTemplate(BasePromptTemplate, ABC):
         t = Template(prompt_text)
         try:
-            t.substitute(**{param: "" for param in prompt_params})
+            t.substitute(**{param: "" for param in self.prompt_params})
         except KeyError as e:
             raise ValueError(
                 f"Invalid parameter {e} in prompt text "
-                f"{prompt_text} for prompt template {name}, specified parameters are {prompt_params}"
+                f"{prompt_text} for prompt template {name}, specified parameters are {self.prompt_params}"
             )
 
         self.name = name
         self.prompt_text = prompt_text
-        self.prompt_params = prompt_params
 
     def prepare(self, *args, **kwargs) -> Dict[str, Any]:
         """


pyproject.toml

@@ -99,7 +99,10 @@ dependencies = [
     "jsonschema",
 
     # Preview
-    "canals"
+    "canals",
+
+    # Agent events
+    "events"
 ]
 
 [project.optional-dependencies]