mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-25 08:04:49 +00:00
feat:Add agent event callbacks (#4491)
* Implement agent callbacks with events * Fix mypy errors * Fix prompt_params assignment * PR review fixes --------- Co-authored-by: Silvano Cerza <silvanocerza@gmail.com>
This commit is contained in:
parent
2a2226d63e
commit
c99b58100d
@ -4,10 +4,15 @@ import logging
|
||||
import re
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
|
||||
from events import Events
|
||||
|
||||
from haystack import Pipeline, BaseComponent, Answer, Document
|
||||
from haystack.agents.agent_step import AgentStep
|
||||
from haystack.agents.types import Color
|
||||
from haystack.agents.utils import print_text
|
||||
from haystack.errors import AgentError
|
||||
from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
|
||||
from haystack.nodes.prompt.providers import TokenStreamingHandler
|
||||
from haystack.pipelines import (
|
||||
BaseStandardPipeline,
|
||||
ExtractiveQAPipeline,
|
||||
@ -57,7 +62,8 @@ class Tool:
|
||||
RetrieverQuestionGenerationPipeline,
|
||||
],
|
||||
description: str,
|
||||
output_variable: Optional[str] = "results",
|
||||
output_variable: str = "results",
|
||||
logging_color: Color = Color.YELLOW,
|
||||
):
|
||||
if re.search(r"\W", name):
|
||||
raise ValueError(
|
||||
@ -68,6 +74,7 @@ class Tool:
|
||||
self.pipeline_or_node = pipeline_or_node
|
||||
self.description = description
|
||||
self.output_variable = output_variable
|
||||
self.logging_color = logging_color
|
||||
|
||||
def run(self, tool_input: str, params: Optional[dict] = None) -> str:
|
||||
# We can only pass params to pipelines but not to nodes
|
||||
@ -144,6 +151,17 @@ class Agent:
|
||||
text the Agent generated.
|
||||
:param final_answer_pattern: A regular expression to extract the final answer from the text the Agent generated.
|
||||
"""
|
||||
self.callback_manager = Events(
|
||||
(
|
||||
"on_tool_start",
|
||||
"on_tool_finish",
|
||||
"on_tool_error",
|
||||
"on_agent_start",
|
||||
"on_agent_step",
|
||||
"on_agent_finish",
|
||||
"on_new_token",
|
||||
)
|
||||
)
|
||||
self.prompt_node = prompt_node
|
||||
self.prompt_template = (
|
||||
prompt_node.get_prompt_template(prompt_template) if isinstance(prompt_template, str) else prompt_template
|
||||
@ -156,8 +174,31 @@ class Agent:
|
||||
self.max_steps = max_steps
|
||||
self.tool_pattern = tool_pattern
|
||||
self.final_answer_pattern = final_answer_pattern
|
||||
|
||||
self.add_default_logging_callbacks()
|
||||
|
||||
send_custom_event(event=f"{type(self).__name__} initialized")
|
||||
|
||||
def add_default_logging_callbacks(self, agent_color: Color = Color.GREEN) -> None:
|
||||
def on_tool_finish(
|
||||
tool_output: str,
|
||||
color: Optional[Color] = None,
|
||||
observation_prefix: Optional[str] = None,
|
||||
llm_prefix: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
print_text(observation_prefix) # type: ignore
|
||||
print_text(tool_output, color=color)
|
||||
print_text(f"\n{llm_prefix}")
|
||||
|
||||
def on_agent_start(**kwargs: Any) -> None:
|
||||
agent_name = kwargs.pop("name", "react")
|
||||
print_text(f"\nAgent {agent_name} started with {kwargs}\n")
|
||||
|
||||
self.callback_manager.on_tool_finish += on_tool_finish
|
||||
self.callback_manager.on_agent_start += on_agent_start
|
||||
self.callback_manager.on_new_token += lambda token, **kwargs: print_text(token, color=agent_color)
|
||||
|
||||
def add_tool(self, tool: Tool):
|
||||
"""
|
||||
Add a tool to the Agent. This also updates the PromptTemplate for the Agent's PromptNode with the tool name.
|
||||
@ -215,11 +256,13 @@ class Agent:
|
||||
f"max_steps must be at least 2 to let the Agent use a tool once and then infer it knows the final "
|
||||
f"answer. It was set to {max_steps}."
|
||||
)
|
||||
|
||||
self.callback_manager.on_agent_start(name=self.prompt_template.name, query=query, params=params)
|
||||
agent_step = self._create_first_step(query, max_steps)
|
||||
while not agent_step.is_last():
|
||||
agent_step = self._step(agent_step, params)
|
||||
|
||||
try:
|
||||
while not agent_step.is_last():
|
||||
agent_step = self._step(agent_step, params)
|
||||
finally:
|
||||
self.callback_manager.on_agent_finish(agent_step)
|
||||
return agent_step.final_answer(query=query)
|
||||
|
||||
def _create_first_step(self, query: str, max_steps: int = 10):
|
||||
@ -233,11 +276,21 @@ class Agent:
|
||||
)
|
||||
|
||||
def _step(self, current_step: AgentStep, params: Optional[dict] = None):
|
||||
cm = self.callback_manager
|
||||
|
||||
class AgentTokenStreamingHandler(TokenStreamingHandler):
|
||||
def __call__(self, token_received, **kwargs) -> str:
|
||||
cm.on_new_token(token_received, **kwargs)
|
||||
return token_received
|
||||
|
||||
# plan next step using the LLM
|
||||
prompt_node_response = self.prompt_node(current_step.prepare_prompt())
|
||||
prompt_node_response = self.prompt_node(
|
||||
current_step.prepare_prompt(), stream_handler=AgentTokenStreamingHandler()
|
||||
)
|
||||
|
||||
# from the LLM response, create the next step
|
||||
next_step = current_step.create_next_step(prompt_node_response)
|
||||
self.callback_manager.on_agent_step(next_step)
|
||||
|
||||
# run the tool selected by the LLM
|
||||
observation = self._run_tool(next_step, params) if not next_step.is_last() else None
|
||||
@ -286,7 +339,18 @@ class Agent:
|
||||
"Add the tool using `add_tool()` or include it in the parameter `tools` when initializing the Agent."
|
||||
f"Agent Step::\n{next_step}"
|
||||
)
|
||||
return self.tools[tool_name].run(tool_input, params)
|
||||
tool_result: str = ""
|
||||
tool: Tool = self.tools[tool_name]
|
||||
try:
|
||||
self.callback_manager.on_tool_start(tool_input, tool=tool)
|
||||
tool_result = tool.run(tool_input, params)
|
||||
self.callback_manager.on_tool_finish(
|
||||
tool_result, observation_prefix="Observation: ", llm_prefix="Thought: ", color=tool.logging_color
|
||||
)
|
||||
except Exception as e:
|
||||
self.callback_manager.on_tool_error(e, tool=self.tools[tool_name])
|
||||
raise e
|
||||
return tool_result
|
||||
|
||||
def _get_initial_transcript(self, query: str):
|
||||
"""
|
||||
|
13
haystack/agents/types.py
Normal file
13
haystack/agents/types.py
Normal file
@ -0,0 +1,13 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Color(Enum):
|
||||
BLACK = "\033[30m"
|
||||
RED = "\033[31m"
|
||||
GREEN = "\033[32m"
|
||||
YELLOW = "\033[33m"
|
||||
BLUE = "\033[34m"
|
||||
MAGENTA = "\033[35m"
|
||||
CYAN = "\033[36m"
|
||||
WHITE = "\033[37m"
|
||||
RESET = "\x1b[0m"
|
16
haystack/agents/utils.py
Normal file
16
haystack/agents/utils.py
Normal file
@ -0,0 +1,16 @@
|
||||
from typing import Optional
|
||||
|
||||
from haystack.agents.types import Color
|
||||
|
||||
|
||||
def print_text(text: str, end="", color: Optional[Color] = None) -> None:
|
||||
"""
|
||||
Print text with optional color.
|
||||
:param text: Text to print.
|
||||
:param end: End character to use (defaults to "").
|
||||
:param color: Color to print text in (defaults to None).
|
||||
"""
|
||||
if color:
|
||||
print(f"{color.value}{text}{Color.RESET.value}", end=end, flush=True)
|
||||
else:
|
||||
print(text, end=end, flush=True)
|
@ -71,15 +71,17 @@ class PromptTemplate(BasePromptTemplate, ABC):
|
||||
:param prompt_params: Optional parameters that need to be filled in the prompt text. If you don't specify them, they're inferred from the prompt text. Any variable in prompt text in the format `$variablename` is interpreted as a prompt parameter.
|
||||
"""
|
||||
super().__init__()
|
||||
if not prompt_params:
|
||||
if prompt_params:
|
||||
self.prompt_params = prompt_params
|
||||
else:
|
||||
# Define the regex pattern to match the strings after the $ character
|
||||
pattern = r"\$([a-zA-Z0-9_]+)"
|
||||
prompt_params = re.findall(pattern, prompt_text)
|
||||
self.prompt_params = re.findall(pattern, prompt_text)
|
||||
|
||||
if prompt_text.count("$") != len(prompt_params):
|
||||
if prompt_text.count("$") != len(self.prompt_params):
|
||||
raise ValueError(
|
||||
f"The number of parameters in prompt text {prompt_text} for prompt template {name} "
|
||||
f"does not match the number of specified parameters {prompt_params}."
|
||||
f"does not match the number of specified parameters {self.prompt_params}."
|
||||
)
|
||||
|
||||
# use case when PromptTemplate is loaded from a YAML file, we need to start and end the prompt text with quotes
|
||||
@ -87,16 +89,15 @@ class PromptTemplate(BasePromptTemplate, ABC):
|
||||
|
||||
t = Template(prompt_text)
|
||||
try:
|
||||
t.substitute(**{param: "" for param in prompt_params})
|
||||
t.substitute(**{param: "" for param in self.prompt_params})
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
f"Invalid parameter {e} in prompt text "
|
||||
f"{prompt_text} for prompt template {name}, specified parameters are {prompt_params}"
|
||||
f"{prompt_text} for prompt template {name}, specified parameters are {self.prompt_params}"
|
||||
)
|
||||
|
||||
self.name = name
|
||||
self.prompt_text = prompt_text
|
||||
self.prompt_params = prompt_params
|
||||
|
||||
def prepare(self, *args, **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
|
@ -99,7 +99,10 @@ dependencies = [
|
||||
"jsonschema",
|
||||
|
||||
# Preview
|
||||
"canals"
|
||||
"canals",
|
||||
|
||||
# Agent events
|
||||
"events"
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
Loading…
x
Reference in New Issue
Block a user