mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-09-26 08:33:51 +00:00
feat:Add agent event callbacks (#4491)
* Implement agent callbacks with events * Fix mypy errors * Fix prompt_params assignment * PR review fixes --------- Co-authored-by: Silvano Cerza <silvanocerza@gmail.com>
This commit is contained in:
parent
2a2226d63e
commit
c99b58100d
@ -4,10 +4,15 @@ import logging
|
|||||||
import re
|
import re
|
||||||
from typing import List, Optional, Union, Dict, Any
|
from typing import List, Optional, Union, Dict, Any
|
||||||
|
|
||||||
|
from events import Events
|
||||||
|
|
||||||
from haystack import Pipeline, BaseComponent, Answer, Document
|
from haystack import Pipeline, BaseComponent, Answer, Document
|
||||||
from haystack.agents.agent_step import AgentStep
|
from haystack.agents.agent_step import AgentStep
|
||||||
|
from haystack.agents.types import Color
|
||||||
|
from haystack.agents.utils import print_text
|
||||||
from haystack.errors import AgentError
|
from haystack.errors import AgentError
|
||||||
from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
|
from haystack.nodes import PromptNode, BaseRetriever, PromptTemplate
|
||||||
|
from haystack.nodes.prompt.providers import TokenStreamingHandler
|
||||||
from haystack.pipelines import (
|
from haystack.pipelines import (
|
||||||
BaseStandardPipeline,
|
BaseStandardPipeline,
|
||||||
ExtractiveQAPipeline,
|
ExtractiveQAPipeline,
|
||||||
@ -57,7 +62,8 @@ class Tool:
|
|||||||
RetrieverQuestionGenerationPipeline,
|
RetrieverQuestionGenerationPipeline,
|
||||||
],
|
],
|
||||||
description: str,
|
description: str,
|
||||||
output_variable: Optional[str] = "results",
|
output_variable: str = "results",
|
||||||
|
logging_color: Color = Color.YELLOW,
|
||||||
):
|
):
|
||||||
if re.search(r"\W", name):
|
if re.search(r"\W", name):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -68,6 +74,7 @@ class Tool:
|
|||||||
self.pipeline_or_node = pipeline_or_node
|
self.pipeline_or_node = pipeline_or_node
|
||||||
self.description = description
|
self.description = description
|
||||||
self.output_variable = output_variable
|
self.output_variable = output_variable
|
||||||
|
self.logging_color = logging_color
|
||||||
|
|
||||||
def run(self, tool_input: str, params: Optional[dict] = None) -> str:
|
def run(self, tool_input: str, params: Optional[dict] = None) -> str:
|
||||||
# We can only pass params to pipelines but not to nodes
|
# We can only pass params to pipelines but not to nodes
|
||||||
@ -144,6 +151,17 @@ class Agent:
|
|||||||
text the Agent generated.
|
text the Agent generated.
|
||||||
:param final_answer_pattern: A regular expression to extract the final answer from the text the Agent generated.
|
:param final_answer_pattern: A regular expression to extract the final answer from the text the Agent generated.
|
||||||
"""
|
"""
|
||||||
|
self.callback_manager = Events(
|
||||||
|
(
|
||||||
|
"on_tool_start",
|
||||||
|
"on_tool_finish",
|
||||||
|
"on_tool_error",
|
||||||
|
"on_agent_start",
|
||||||
|
"on_agent_step",
|
||||||
|
"on_agent_finish",
|
||||||
|
"on_new_token",
|
||||||
|
)
|
||||||
|
)
|
||||||
self.prompt_node = prompt_node
|
self.prompt_node = prompt_node
|
||||||
self.prompt_template = (
|
self.prompt_template = (
|
||||||
prompt_node.get_prompt_template(prompt_template) if isinstance(prompt_template, str) else prompt_template
|
prompt_node.get_prompt_template(prompt_template) if isinstance(prompt_template, str) else prompt_template
|
||||||
@ -156,8 +174,31 @@ class Agent:
|
|||||||
self.max_steps = max_steps
|
self.max_steps = max_steps
|
||||||
self.tool_pattern = tool_pattern
|
self.tool_pattern = tool_pattern
|
||||||
self.final_answer_pattern = final_answer_pattern
|
self.final_answer_pattern = final_answer_pattern
|
||||||
|
|
||||||
|
self.add_default_logging_callbacks()
|
||||||
|
|
||||||
send_custom_event(event=f"{type(self).__name__} initialized")
|
send_custom_event(event=f"{type(self).__name__} initialized")
|
||||||
|
|
||||||
|
def add_default_logging_callbacks(self, agent_color: Color = Color.GREEN) -> None:
|
||||||
|
def on_tool_finish(
|
||||||
|
tool_output: str,
|
||||||
|
color: Optional[Color] = None,
|
||||||
|
observation_prefix: Optional[str] = None,
|
||||||
|
llm_prefix: Optional[str] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
print_text(observation_prefix) # type: ignore
|
||||||
|
print_text(tool_output, color=color)
|
||||||
|
print_text(f"\n{llm_prefix}")
|
||||||
|
|
||||||
|
def on_agent_start(**kwargs: Any) -> None:
|
||||||
|
agent_name = kwargs.pop("name", "react")
|
||||||
|
print_text(f"\nAgent {agent_name} started with {kwargs}\n")
|
||||||
|
|
||||||
|
self.callback_manager.on_tool_finish += on_tool_finish
|
||||||
|
self.callback_manager.on_agent_start += on_agent_start
|
||||||
|
self.callback_manager.on_new_token += lambda token, **kwargs: print_text(token, color=agent_color)
|
||||||
|
|
||||||
def add_tool(self, tool: Tool):
|
def add_tool(self, tool: Tool):
|
||||||
"""
|
"""
|
||||||
Add a tool to the Agent. This also updates the PromptTemplate for the Agent's PromptNode with the tool name.
|
Add a tool to the Agent. This also updates the PromptTemplate for the Agent's PromptNode with the tool name.
|
||||||
@ -215,11 +256,13 @@ class Agent:
|
|||||||
f"max_steps must be at least 2 to let the Agent use a tool once and then infer it knows the final "
|
f"max_steps must be at least 2 to let the Agent use a tool once and then infer it knows the final "
|
||||||
f"answer. It was set to {max_steps}."
|
f"answer. It was set to {max_steps}."
|
||||||
)
|
)
|
||||||
|
self.callback_manager.on_agent_start(name=self.prompt_template.name, query=query, params=params)
|
||||||
agent_step = self._create_first_step(query, max_steps)
|
agent_step = self._create_first_step(query, max_steps)
|
||||||
while not agent_step.is_last():
|
try:
|
||||||
agent_step = self._step(agent_step, params)
|
while not agent_step.is_last():
|
||||||
|
agent_step = self._step(agent_step, params)
|
||||||
|
finally:
|
||||||
|
self.callback_manager.on_agent_finish(agent_step)
|
||||||
return agent_step.final_answer(query=query)
|
return agent_step.final_answer(query=query)
|
||||||
|
|
||||||
def _create_first_step(self, query: str, max_steps: int = 10):
|
def _create_first_step(self, query: str, max_steps: int = 10):
|
||||||
@ -233,11 +276,21 @@ class Agent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _step(self, current_step: AgentStep, params: Optional[dict] = None):
|
def _step(self, current_step: AgentStep, params: Optional[dict] = None):
|
||||||
|
cm = self.callback_manager
|
||||||
|
|
||||||
|
class AgentTokenStreamingHandler(TokenStreamingHandler):
|
||||||
|
def __call__(self, token_received, **kwargs) -> str:
|
||||||
|
cm.on_new_token(token_received, **kwargs)
|
||||||
|
return token_received
|
||||||
|
|
||||||
# plan next step using the LLM
|
# plan next step using the LLM
|
||||||
prompt_node_response = self.prompt_node(current_step.prepare_prompt())
|
prompt_node_response = self.prompt_node(
|
||||||
|
current_step.prepare_prompt(), stream_handler=AgentTokenStreamingHandler()
|
||||||
|
)
|
||||||
|
|
||||||
# from the LLM response, create the next step
|
# from the LLM response, create the next step
|
||||||
next_step = current_step.create_next_step(prompt_node_response)
|
next_step = current_step.create_next_step(prompt_node_response)
|
||||||
|
self.callback_manager.on_agent_step(next_step)
|
||||||
|
|
||||||
# run the tool selected by the LLM
|
# run the tool selected by the LLM
|
||||||
observation = self._run_tool(next_step, params) if not next_step.is_last() else None
|
observation = self._run_tool(next_step, params) if not next_step.is_last() else None
|
||||||
@ -286,7 +339,18 @@ class Agent:
|
|||||||
"Add the tool using `add_tool()` or include it in the parameter `tools` when initializing the Agent."
|
"Add the tool using `add_tool()` or include it in the parameter `tools` when initializing the Agent."
|
||||||
f"Agent Step::\n{next_step}"
|
f"Agent Step::\n{next_step}"
|
||||||
)
|
)
|
||||||
return self.tools[tool_name].run(tool_input, params)
|
tool_result: str = ""
|
||||||
|
tool: Tool = self.tools[tool_name]
|
||||||
|
try:
|
||||||
|
self.callback_manager.on_tool_start(tool_input, tool=tool)
|
||||||
|
tool_result = tool.run(tool_input, params)
|
||||||
|
self.callback_manager.on_tool_finish(
|
||||||
|
tool_result, observation_prefix="Observation: ", llm_prefix="Thought: ", color=tool.logging_color
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.callback_manager.on_tool_error(e, tool=self.tools[tool_name])
|
||||||
|
raise e
|
||||||
|
return tool_result
|
||||||
|
|
||||||
def _get_initial_transcript(self, query: str):
|
def _get_initial_transcript(self, query: str):
|
||||||
"""
|
"""
|
||||||
|
13
haystack/agents/types.py
Normal file
13
haystack/agents/types.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
|
||||||
|
class Color(Enum):
|
||||||
|
BLACK = "\033[30m"
|
||||||
|
RED = "\033[31m"
|
||||||
|
GREEN = "\033[32m"
|
||||||
|
YELLOW = "\033[33m"
|
||||||
|
BLUE = "\033[34m"
|
||||||
|
MAGENTA = "\033[35m"
|
||||||
|
CYAN = "\033[36m"
|
||||||
|
WHITE = "\033[37m"
|
||||||
|
RESET = "\x1b[0m"
|
16
haystack/agents/utils.py
Normal file
16
haystack/agents/utils.py
Normal file
@ -0,0 +1,16 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from haystack.agents.types import Color
|
||||||
|
|
||||||
|
|
||||||
|
def print_text(text: str, end="", color: Optional[Color] = None) -> None:
|
||||||
|
"""
|
||||||
|
Print text with optional color.
|
||||||
|
:param text: Text to print.
|
||||||
|
:param end: End character to use (defaults to "").
|
||||||
|
:param color: Color to print text in (defaults to None).
|
||||||
|
"""
|
||||||
|
if color:
|
||||||
|
print(f"{color.value}{text}{Color.RESET.value}", end=end, flush=True)
|
||||||
|
else:
|
||||||
|
print(text, end=end, flush=True)
|
@ -71,15 +71,17 @@ class PromptTemplate(BasePromptTemplate, ABC):
|
|||||||
:param prompt_params: Optional parameters that need to be filled in the prompt text. If you don't specify them, they're inferred from the prompt text. Any variable in prompt text in the format `$variablename` is interpreted as a prompt parameter.
|
:param prompt_params: Optional parameters that need to be filled in the prompt text. If you don't specify them, they're inferred from the prompt text. Any variable in prompt text in the format `$variablename` is interpreted as a prompt parameter.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if not prompt_params:
|
if prompt_params:
|
||||||
|
self.prompt_params = prompt_params
|
||||||
|
else:
|
||||||
# Define the regex pattern to match the strings after the $ character
|
# Define the regex pattern to match the strings after the $ character
|
||||||
pattern = r"\$([a-zA-Z0-9_]+)"
|
pattern = r"\$([a-zA-Z0-9_]+)"
|
||||||
prompt_params = re.findall(pattern, prompt_text)
|
self.prompt_params = re.findall(pattern, prompt_text)
|
||||||
|
|
||||||
if prompt_text.count("$") != len(prompt_params):
|
if prompt_text.count("$") != len(self.prompt_params):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"The number of parameters in prompt text {prompt_text} for prompt template {name} "
|
f"The number of parameters in prompt text {prompt_text} for prompt template {name} "
|
||||||
f"does not match the number of specified parameters {prompt_params}."
|
f"does not match the number of specified parameters {self.prompt_params}."
|
||||||
)
|
)
|
||||||
|
|
||||||
# use case when PromptTemplate is loaded from a YAML file, we need to start and end the prompt text with quotes
|
# use case when PromptTemplate is loaded from a YAML file, we need to start and end the prompt text with quotes
|
||||||
@ -87,16 +89,15 @@ class PromptTemplate(BasePromptTemplate, ABC):
|
|||||||
|
|
||||||
t = Template(prompt_text)
|
t = Template(prompt_text)
|
||||||
try:
|
try:
|
||||||
t.substitute(**{param: "" for param in prompt_params})
|
t.substitute(**{param: "" for param in self.prompt_params})
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid parameter {e} in prompt text "
|
f"Invalid parameter {e} in prompt text "
|
||||||
f"{prompt_text} for prompt template {name}, specified parameters are {prompt_params}"
|
f"{prompt_text} for prompt template {name}, specified parameters are {self.prompt_params}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.prompt_text = prompt_text
|
self.prompt_text = prompt_text
|
||||||
self.prompt_params = prompt_params
|
|
||||||
|
|
||||||
def prepare(self, *args, **kwargs) -> Dict[str, Any]:
|
def prepare(self, *args, **kwargs) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
@ -99,7 +99,10 @@ dependencies = [
|
|||||||
"jsonschema",
|
"jsonschema",
|
||||||
|
|
||||||
# Preview
|
# Preview
|
||||||
"canals"
|
"canals",
|
||||||
|
|
||||||
|
# Agent events
|
||||||
|
"events"
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user