feat:Add agent event callbacks (#4491)

* Implement agent callbacks with events

* Fix mypy errors

* Fix prompt_params assignment

* PR review fixes

---------

Co-authored-by: Silvano Cerza <silvanocerza@gmail.com>
This commit is contained in:
Vladimir Blagojevic 2023-03-27 10:06:11 +02:00 committed by GitHub
parent 2a2226d63e
commit c99b58100d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 112 additions and 15 deletions

View File

@ -4,10 +4,15 @@ import logging
import re
from typing import List, Optional, Union, Dict, Any
from events import Events
from haystack import Pipeline, BaseComponent, Answer, Document
from haystack.agents.agent_step import AgentStep
from haystack.agents.types import Color
from haystack.agents.utils import print_text
from haystack.errors import AgentError
from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
from haystack.nodes.prompt.providers import TokenStreamingHandler
from haystack.pipelines import (
BaseStandardPipeline,
ExtractiveQAPipeline,
@ -57,7 +62,8 @@ class Tool:
RetrieverQuestionGenerationPipeline,
],
description: str,
output_variable: Optional[str] = "results",
output_variable: str = "results",
logging_color: Color = Color.YELLOW,
):
if re.search(r"\W", name):
raise ValueError(
@ -68,6 +74,7 @@ class Tool:
self.pipeline_or_node = pipeline_or_node
self.description = description
self.output_variable = output_variable
self.logging_color = logging_color
def run(self, tool_input: str, params: Optional[dict] = None) -> str:
# We can only pass params to pipelines but not to nodes
@ -144,6 +151,17 @@ class Agent:
text the Agent generated.
:param final_answer_pattern: A regular expression to extract the final answer from the text the Agent generated.
"""
self.callback_manager = Events(
(
"on_tool_start",
"on_tool_finish",
"on_tool_error",
"on_agent_start",
"on_agent_step",
"on_agent_finish",
"on_new_token",
)
)
self.prompt_node = prompt_node
self.prompt_template = (
prompt_node.get_prompt_template(prompt_template) if isinstance(prompt_template, str) else prompt_template
@ -156,8 +174,31 @@ class Agent:
self.max_steps = max_steps
self.tool_pattern = tool_pattern
self.final_answer_pattern = final_answer_pattern
self.add_default_logging_callbacks()
send_custom_event(event=f"{type(self).__name__} initialized")
def add_default_logging_callbacks(self, agent_color: Color = Color.GREEN) -> None:
def on_tool_finish(
tool_output: str,
color: Optional[Color] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
print_text(observation_prefix) # type: ignore
print_text(tool_output, color=color)
print_text(f"\n{llm_prefix}")
def on_agent_start(**kwargs: Any) -> None:
agent_name = kwargs.pop("name", "react")
print_text(f"\nAgent {agent_name} started with {kwargs}\n")
self.callback_manager.on_tool_finish += on_tool_finish
self.callback_manager.on_agent_start += on_agent_start
self.callback_manager.on_new_token += lambda token, **kwargs: print_text(token, color=agent_color)
def add_tool(self, tool: Tool):
"""
Add a tool to the Agent. This also updates the PromptTemplate for the Agent's PromptNode with the tool name.
@ -215,11 +256,13 @@ class Agent:
f"max_steps must be at least 2 to let the Agent use a tool once and then infer it knows the final "
f"answer. It was set to {max_steps}."
)
self.callback_manager.on_agent_start(name=self.prompt_template.name, query=query, params=params)
agent_step = self._create_first_step(query, max_steps)
while not agent_step.is_last():
agent_step = self._step(agent_step, params)
try:
while not agent_step.is_last():
agent_step = self._step(agent_step, params)
finally:
self.callback_manager.on_agent_finish(agent_step)
return agent_step.final_answer(query=query)
def _create_first_step(self, query: str, max_steps: int = 10):
@ -233,11 +276,21 @@ class Agent:
)
def _step(self, current_step: AgentStep, params: Optional[dict] = None):
cm = self.callback_manager
class AgentTokenStreamingHandler(TokenStreamingHandler):
def __call__(self, token_received, **kwargs) -> str:
cm.on_new_token(token_received, **kwargs)
return token_received
# plan next step using the LLM
prompt_node_response = self.prompt_node(current_step.prepare_prompt())
prompt_node_response = self.prompt_node(
current_step.prepare_prompt(), stream_handler=AgentTokenStreamingHandler()
)
# from the LLM response, create the next step
next_step = current_step.create_next_step(prompt_node_response)
self.callback_manager.on_agent_step(next_step)
# run the tool selected by the LLM
observation = self._run_tool(next_step, params) if not next_step.is_last() else None
@ -286,7 +339,18 @@ class Agent:
"Add the tool using `add_tool()` or include it in the parameter `tools` when initializing the Agent."
f"Agent Step::\n{next_step}"
)
return self.tools[tool_name].run(tool_input, params)
tool_result: str = ""
tool: Tool = self.tools[tool_name]
try:
self.callback_manager.on_tool_start(tool_input, tool=tool)
tool_result = tool.run(tool_input, params)
self.callback_manager.on_tool_finish(
tool_result, observation_prefix="Observation: ", llm_prefix="Thought: ", color=tool.logging_color
)
except Exception as e:
self.callback_manager.on_tool_error(e, tool=self.tools[tool_name])
raise e
return tool_result
def _get_initial_transcript(self, query: str):
"""

13
haystack/agents/types.py Normal file
View File

@ -0,0 +1,13 @@
from enum import Enum
class Color(Enum):
BLACK = "\033[30m"
RED = "\033[31m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
BLUE = "\033[34m"
MAGENTA = "\033[35m"
CYAN = "\033[36m"
WHITE = "\033[37m"
RESET = "\x1b[0m"

16
haystack/agents/utils.py Normal file
View File

@ -0,0 +1,16 @@
from typing import Optional
from haystack.agents.types import Color
def print_text(text: str, end="", color: Optional[Color] = None) -> None:
"""
Print text with optional color.
:param text: Text to print.
:param end: End character to use (defaults to "").
:param color: Color to print text in (defaults to None).
"""
if color:
print(f"{color.value}{text}{Color.RESET.value}", end=end, flush=True)
else:
print(text, end=end, flush=True)

View File

@ -71,15 +71,17 @@ class PromptTemplate(BasePromptTemplate, ABC):
:param prompt_params: Optional parameters that need to be filled in the prompt text. If you don't specify them, they're inferred from the prompt text. Any variable in prompt text in the format `$variablename` is interpreted as a prompt parameter.
"""
super().__init__()
if not prompt_params:
if prompt_params:
self.prompt_params = prompt_params
else:
# Define the regex pattern to match the strings after the $ character
pattern = r"\$([a-zA-Z0-9_]+)"
prompt_params = re.findall(pattern, prompt_text)
self.prompt_params = re.findall(pattern, prompt_text)
if prompt_text.count("$") != len(prompt_params):
if prompt_text.count("$") != len(self.prompt_params):
raise ValueError(
f"The number of parameters in prompt text {prompt_text} for prompt template {name} "
f"does not match the number of specified parameters {prompt_params}."
f"does not match the number of specified parameters {self.prompt_params}."
)
# use case when PromptTemplate is loaded from a YAML file, we need to start and end the prompt text with quotes
@ -87,16 +89,15 @@ class PromptTemplate(BasePromptTemplate, ABC):
t = Template(prompt_text)
try:
t.substitute(**{param: "" for param in prompt_params})
t.substitute(**{param: "" for param in self.prompt_params})
except KeyError as e:
raise ValueError(
f"Invalid parameter {e} in prompt text "
f"{prompt_text} for prompt template {name}, specified parameters are {prompt_params}"
f"{prompt_text} for prompt template {name}, specified parameters are {self.prompt_params}"
)
self.name = name
self.prompt_text = prompt_text
self.prompt_params = prompt_params
def prepare(self, *args, **kwargs) -> Dict[str, Any]:
"""

View File

@ -99,7 +99,10 @@ dependencies = [
"jsonschema",
# Preview
"canals"
"canals",
# Agent events
"events"
]
[project.optional-dependencies]