mirror of
https://github.com/deepset-ai/haystack.git
synced 2025-06-26 22:00:13 +00:00
wip: adding breakpoints to agent
This commit is contained in:
parent
8e21c501df
commit
c798024d5b
@ -3,7 +3,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from haystack import component, default_from_dict, default_to_dict, logging, tracing
|
||||
from haystack.components.generators.chat.types import ChatGenerator
|
||||
@ -24,6 +28,72 @@ from .state.state_utils import merge_lists
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentBreakpointException(Exception):
|
||||
"""
|
||||
Exception raised when an agent breakpoint is triggered.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
component: Optional[str] = None,
|
||||
state: Optional[Dict[str, Any]] = None,
|
||||
results: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.component = component
|
||||
self.state = state
|
||||
self.results = results
|
||||
|
||||
|
||||
class AgentInvalidResumeStateError(Exception):
|
||||
"""
|
||||
Exception raised when an agent is resumed from an invalid state.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def _serialize_agent_data(value: Any) -> Any:
|
||||
"""
|
||||
Serializes agent data so it can be saved to a file.
|
||||
"""
|
||||
if hasattr(value, "to_dict") and callable(getattr(value, "to_dict")):
|
||||
serialized_value = value.to_dict()
|
||||
serialized_value["_type"] = value.__class__.__name__
|
||||
return serialized_value
|
||||
elif hasattr(value, "__dict__"):
|
||||
return {"_type": value.__class__.__name__, "attributes": value.__dict__}
|
||||
elif isinstance(value, dict):
|
||||
return {k: _serialize_agent_data(v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [_serialize_agent_data(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def _deserialize_agent_data(value: Any) -> Any:
|
||||
"""
|
||||
Deserializes agent data loaded from a file.
|
||||
"""
|
||||
if not value or isinstance(value, (str, int, float, bool)):
|
||||
return value
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(isinstance(i, (str, int, float, bool)) for i in value):
|
||||
return value
|
||||
return [_deserialize_agent_data(i) for i in value]
|
||||
|
||||
if isinstance(value, dict):
|
||||
if "_type" in value:
|
||||
type_name = value.pop("_type")
|
||||
if type_name == "State":
|
||||
return State.from_dict(value)
|
||||
return {k: _deserialize_agent_data(v) for k, v in value.items()}
|
||||
|
||||
return value
|
||||
|
||||
|
||||
@component
|
||||
class Agent:
|
||||
"""
|
||||
@ -220,8 +290,234 @@ class Agent:
|
||||
},
|
||||
)
|
||||
|
||||
def _validate_breakpoints(self, breakpoints: Set[Tuple[str, Optional[int]]]) -> Set[Tuple[str, int]]:
|
||||
"""
|
||||
Validates the breakpoints passed to the agent.
|
||||
|
||||
Valid breakpoint components are:
|
||||
- "chat_generator": Breaks before calling the chat generator
|
||||
- "tool_invoker": Breaks before calling any tool
|
||||
- "tool_invoker:tool_name": Breaks before calling a specific tool
|
||||
|
||||
:param breakpoints: Set of tuples of component names and visit counts at which the agent should stop.
|
||||
:returns: Set of valid breakpoints.
|
||||
"""
|
||||
processed_breakpoints: Set[Tuple[str, int]] = set()
|
||||
valid_components = {"chat_generator", "tool_invoker"}
|
||||
|
||||
# Add tool-specific breakpoints to valid components
|
||||
for tool in self.tools:
|
||||
valid_components.add(f"tool_invoker:{tool.name}")
|
||||
|
||||
for break_point in breakpoints:
|
||||
component_name = break_point[0]
|
||||
if component_name not in valid_components:
|
||||
raise ValueError(
|
||||
f"Breakpoint '{component_name}' is not a valid component. "
|
||||
f"Valid components are: {sorted(valid_components)}"
|
||||
)
|
||||
valid_breakpoint: Tuple[str, int] = (component_name, 0 if break_point[1] is None else break_point[1])
|
||||
processed_breakpoints.add(valid_breakpoint)
|
||||
return processed_breakpoints
|
||||
|
||||
@staticmethod
|
||||
def save_state(
|
||||
state: State,
|
||||
component_visits: Dict[str, int],
|
||||
generator_inputs: Dict[str, Any],
|
||||
component_name: str,
|
||||
original_messages: List[ChatMessage],
|
||||
original_kwargs: Dict[str, Any],
|
||||
debug_path: Optional[Union[str, Path]] = None,
|
||||
callback_fun: Optional[Callable[..., Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Save the current agent state to a JSON file or return as a dictionary.
|
||||
|
||||
:param state: Current agent State object
|
||||
:param component_visits: Dictionary tracking component visit counts
|
||||
:param generator_inputs: Current generator inputs
|
||||
:param component_name: Name of the component where breakpoint was triggered
|
||||
:param original_messages: Original input messages
|
||||
:param original_kwargs: Original input kwargs
|
||||
:param debug_path: Optional path to save the state file
|
||||
:param callback_fun: Optional callback function to call with the state
|
||||
:returns: The saved state dictionary
|
||||
"""
|
||||
dt = datetime.now()
|
||||
|
||||
print(component_visits)
|
||||
|
||||
# For tool-specific breakpoints, extract the base component name for visit count lookup
|
||||
base_component_name = component_name.split(":")[0] if ":" in component_name else component_name
|
||||
|
||||
agent_state = {
|
||||
"input_data": _serialize_agent_data({"messages": original_messages, "kwargs": original_kwargs}),
|
||||
"timestamp": dt.isoformat(),
|
||||
"breakpoint": {"component": component_name, "visits": component_visits[base_component_name]},
|
||||
"agent_state": {
|
||||
"state": _serialize_agent_data(state),
|
||||
"component_visits": component_visits,
|
||||
"generator_inputs": _serialize_agent_data(generator_inputs),
|
||||
},
|
||||
}
|
||||
|
||||
if not debug_path:
|
||||
return agent_state
|
||||
|
||||
if isinstance(debug_path, str):
|
||||
debug_path = Path(debug_path)
|
||||
if not isinstance(debug_path, Path):
|
||||
raise ValueError("Debug path must be a string or a Path object.")
|
||||
debug_path.mkdir(exist_ok=True)
|
||||
file_name = Path(f"agent_{component_name}_{dt.strftime('%Y_%m_%d_%H_%M_%S')}.json")
|
||||
|
||||
try:
|
||||
with open(debug_path / file_name, "w") as f_out:
|
||||
json.dump(agent_state, f_out, indent=2)
|
||||
logger.info(f"Agent state saved at: {file_name}")
|
||||
|
||||
if callback_fun is not None:
|
||||
callback_fun(agent_state)
|
||||
|
||||
return agent_state
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save agent state: {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def load_state(file_path: Union[str, Path]) -> Dict[str, Any]:
|
||||
"""
|
||||
Load a saved agent state.
|
||||
|
||||
:param file_path: Path to the state file
|
||||
:returns: Dict containing the loaded state
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
|
||||
try:
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
state = json.load(f)
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
except json.JSONDecodeError as e:
|
||||
raise json.JSONDecodeError(f"Invalid JSON file {file_path}: {str(e)}", e.doc, e.pos)
|
||||
except IOError as e:
|
||||
raise IOError(f"Error reading {file_path}: {str(e)}")
|
||||
|
||||
try:
|
||||
Agent._validate_resume_state(state)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid agent state from {file_path}: {str(e)}")
|
||||
|
||||
logger.info(f"Successfully loaded agent state from: {file_path}")
|
||||
return state
|
||||
|
||||
@staticmethod
|
||||
def _validate_resume_state(state: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Validates the loaded agent state.
|
||||
|
||||
:param state: The state to validate
|
||||
"""
|
||||
required_top_keys = {"input_data", "breakpoint", "agent_state"}
|
||||
missing_top = required_top_keys - state.keys()
|
||||
if missing_top:
|
||||
raise ValueError(f"Invalid state file: missing required keys {missing_top}")
|
||||
|
||||
agent_state = state["agent_state"]
|
||||
required_agent_keys = {"state", "component_visits", "generator_inputs"}
|
||||
missing_agent = required_agent_keys - agent_state.keys()
|
||||
if missing_agent:
|
||||
raise ValueError(f"Invalid agent_state: missing required keys {missing_agent}")
|
||||
|
||||
logger.info("Agent resume state validated successfully.")
|
||||
|
||||
def _check_breakpoints(
|
||||
self,
|
||||
breakpoints: Set[Tuple[str, int]],
|
||||
component_name: str,
|
||||
component_visits: Dict[str, int],
|
||||
state: State,
|
||||
generator_inputs: Dict[str, Any],
|
||||
original_messages: List[ChatMessage],
|
||||
original_kwargs: Dict[str, Any],
|
||||
tool_name: Optional[str] = None,
|
||||
debug_path: Optional[Union[str, Path]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Check if a breakpoint should be triggered for the given component.
|
||||
|
||||
:param breakpoints: Set of breakpoints to check against
|
||||
:param component_name: Name of the component to check
|
||||
:param component_visits: Dictionary tracking component visit counts
|
||||
:param state: Current agent state
|
||||
:param generator_inputs: Current generator inputs
|
||||
:param original_messages: Original input messages
|
||||
:param original_kwargs: Original input kwargs
|
||||
:param tool_name: Optional tool name for tool-specific breakpoints
|
||||
:param debug_path: Optional path to save debug state
|
||||
"""
|
||||
# Check general component breakpoints
|
||||
matching_breakpoints = [bp for bp in breakpoints if bp[0] == component_name]
|
||||
|
||||
# Check tool-specific breakpoints if tool_name is provided
|
||||
if tool_name and component_name == "tool_invoker":
|
||||
tool_specific_name = f"tool_invoker:{tool_name}"
|
||||
tool_breakpoints = [bp for bp in breakpoints if bp[0] == tool_specific_name]
|
||||
matching_breakpoints.extend(tool_breakpoints)
|
||||
|
||||
for bp in matching_breakpoints:
|
||||
visit_count = bp[1]
|
||||
if visit_count == component_visits[component_name]:
|
||||
msg = f"Breaking at component {bp[0]} visit count {component_visits[component_name]}"
|
||||
logger.info(msg)
|
||||
saved_state = self.save_state(
|
||||
state=state,
|
||||
component_visits=component_visits,
|
||||
generator_inputs=generator_inputs,
|
||||
component_name=bp[0],
|
||||
original_messages=original_messages,
|
||||
original_kwargs=original_kwargs,
|
||||
debug_path=debug_path,
|
||||
)
|
||||
raise AgentBreakpointException(msg, component=bp[0], state=saved_state)
|
||||
|
||||
def inject_resume_state(
|
||||
self, resume_state: Dict[str, Any]
|
||||
) -> Tuple[State, Dict[str, int], Dict[str, Any], List[ChatMessage]]:
|
||||
"""
|
||||
Inject resume state into the agent for continuing execution.
|
||||
|
||||
:param resume_state: The saved state to resume from
|
||||
:returns: Tuple of (state, component_visits, generator_inputs, messages)
|
||||
"""
|
||||
if not resume_state:
|
||||
raise AgentInvalidResumeStateError("Cannot inject resume state: resume_state is None")
|
||||
|
||||
self._validate_resume_state(resume_state)
|
||||
agent_state_data = resume_state["agent_state"]
|
||||
state = _deserialize_agent_data(agent_state_data["state"])
|
||||
component_visits = agent_state_data["component_visits"]
|
||||
generator_inputs = _deserialize_agent_data(agent_state_data["generator_inputs"])
|
||||
messages = state.get("messages")
|
||||
|
||||
logger.info(
|
||||
f"Resuming agent from {resume_state['breakpoint']['component']} "
|
||||
f"visit count {resume_state['breakpoint']['visits']}"
|
||||
)
|
||||
|
||||
return state, component_visits, generator_inputs, messages
|
||||
|
||||
def run(
|
||||
self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
breakpoints: Optional[Set[Tuple[str, Optional[int]]]] = None,
|
||||
resume_state: Optional[Dict[str, Any]] = None,
|
||||
debug_path: Optional[Union[str, Path]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Process messages and execute tools until an exit condition is met.
|
||||
@ -230,6 +526,11 @@ class Agent:
|
||||
If a list of dictionaries is provided, each dictionary will be converted to a ChatMessage object.
|
||||
:param streaming_callback: A callback that will be invoked when a response is streamed from the LLM.
|
||||
The same callback can be configured to emit tool results when a tool is called.
|
||||
:param breakpoints: Set of tuples of component names and visit counts at which the agent should break execution.
|
||||
Valid components are "chat_generator", "tool_invoker", or "tool_invoker:tool_name" for specific tools.
|
||||
If the visit count is not given, it defaults to 0 (break on first visit).
|
||||
:param resume_state: A dictionary containing the state of a previously saved agent execution.
|
||||
:param debug_path: Path to the directory where the agent state should be saved when breakpoints are hit.
|
||||
:param kwargs: Additional data to pass to the State schema used by the Agent.
|
||||
The keys must match the schema defined in the Agent's `state_schema`.
|
||||
:returns:
|
||||
@ -237,65 +538,127 @@ class Agent:
|
||||
- "messages": List of all messages exchanged during the agent's run.
|
||||
- "last_message": The last message exchanged during the agent's run.
|
||||
- Any additional keys defined in the `state_schema`.
|
||||
:raises AgentBreakpointException: When a breakpoint is triggered.
|
||||
"""
|
||||
if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
|
||||
raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run()'.")
|
||||
|
||||
if self.system_prompt is not None:
|
||||
messages = [ChatMessage.from_system(self.system_prompt)] + messages
|
||||
if breakpoints and resume_state:
|
||||
logger.warning("Breakpoints cannot be provided when resuming an agent. All breakpoints will be ignored.")
|
||||
|
||||
state = State(schema=self.state_schema, data=kwargs)
|
||||
state.set("messages", messages)
|
||||
component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)
|
||||
# Validate breakpoints if provided
|
||||
validated_breakpoints = self._validate_breakpoints(breakpoints) if breakpoints else None
|
||||
|
||||
# Store original inputs for potential resume
|
||||
original_messages = deepcopy(messages)
|
||||
original_kwargs = deepcopy(kwargs)
|
||||
|
||||
# Handle resume state or initialize fresh state
|
||||
if resume_state:
|
||||
# Restore state from resume data
|
||||
agent_state_data = resume_state["agent_state"]
|
||||
state = _deserialize_agent_data(agent_state_data["state"])
|
||||
component_visits = agent_state_data["component_visits"]
|
||||
generator_inputs = _deserialize_agent_data(agent_state_data["generator_inputs"])
|
||||
messages = state.get("messages")
|
||||
logger.info(
|
||||
f"Resuming agent from {resume_state['breakpoint']['component']} "
|
||||
f"visit count {resume_state['breakpoint']['visits']}"
|
||||
)
|
||||
else:
|
||||
if self.system_prompt is not None:
|
||||
messages = [ChatMessage.from_system(self.system_prompt)] + messages
|
||||
|
||||
state = State(schema=self.state_schema, data=kwargs)
|
||||
state.set("messages", messages)
|
||||
component_visits = dict.fromkeys(["chat_generator", "tool_invoker"], 0)
|
||||
|
||||
streaming_callback = select_streaming_callback(
|
||||
init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
|
||||
)
|
||||
generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback)
|
||||
if not resume_state:
|
||||
generator_inputs = self._prepare_generator_inputs(streaming_callback=streaming_callback)
|
||||
|
||||
with self._create_agent_span() as span:
|
||||
span.set_content_tag(
|
||||
"haystack.agent.input",
|
||||
_deepcopy_with_exceptions({"messages": messages, "streaming_callback": streaming_callback, **kwargs}),
|
||||
)
|
||||
counter = 0
|
||||
while counter < self.max_agent_steps:
|
||||
# 1. Call the ChatGenerator
|
||||
result = Pipeline._run_component(
|
||||
component_name="chat_generator",
|
||||
component={"instance": self.chat_generator},
|
||||
inputs={"messages": messages, **generator_inputs},
|
||||
component_visits=component_visits,
|
||||
parent_span=span,
|
||||
)
|
||||
llm_messages = result["replies"]
|
||||
state.set("messages", llm_messages)
|
||||
try:
|
||||
while counter < self.max_agent_steps:
|
||||
# Check breakpoint before calling ChatGenerator
|
||||
if validated_breakpoints and not resume_state:
|
||||
self._check_breakpoints(
|
||||
breakpoints=validated_breakpoints,
|
||||
component_name="chat_generator",
|
||||
component_visits=component_visits,
|
||||
state=state,
|
||||
generator_inputs=generator_inputs,
|
||||
original_messages=original_messages,
|
||||
original_kwargs=original_kwargs,
|
||||
debug_path=debug_path,
|
||||
)
|
||||
|
||||
# 2. Check if any of the LLM responses contain a tool call or if the LLM is not using tools
|
||||
if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None:
|
||||
# 1. Call the ChatGenerator
|
||||
result = Pipeline._run_component(
|
||||
component_name="chat_generator",
|
||||
component={"instance": self.chat_generator},
|
||||
inputs={"messages": messages, **generator_inputs},
|
||||
component_visits=component_visits,
|
||||
parent_span=span,
|
||||
)
|
||||
llm_messages = result["replies"]
|
||||
state.set("messages", llm_messages)
|
||||
|
||||
# 2. Check if any of the LLM responses contain a tool call or if the LLM is not using tools
|
||||
if not any(msg.tool_call for msg in llm_messages) or self._tool_invoker is None:
|
||||
counter += 1
|
||||
break
|
||||
|
||||
# Check breakpoint before calling ToolInvoker
|
||||
if validated_breakpoints and not resume_state:
|
||||
# Check for tool-specific breakpoints
|
||||
tool_names = [msg.tool_call.tool_name for msg in llm_messages if msg.tool_call]
|
||||
for tool_name in tool_names:
|
||||
self._check_breakpoints(
|
||||
breakpoints=validated_breakpoints,
|
||||
component_name="tool_invoker",
|
||||
component_visits=component_visits,
|
||||
state=state,
|
||||
generator_inputs=generator_inputs,
|
||||
original_messages=original_messages,
|
||||
original_kwargs=original_kwargs,
|
||||
tool_name=tool_name,
|
||||
debug_path=debug_path,
|
||||
)
|
||||
|
||||
# 3. Call the ToolInvoker
|
||||
# We only send the messages from the LLM to the tool invoker
|
||||
tool_invoker_result = Pipeline._run_component(
|
||||
component_name="tool_invoker",
|
||||
component={"instance": self._tool_invoker},
|
||||
inputs={"messages": llm_messages, "state": state, "streaming_callback": streaming_callback},
|
||||
component_visits=component_visits,
|
||||
parent_span=span,
|
||||
)
|
||||
tool_messages = tool_invoker_result["tool_messages"]
|
||||
state = tool_invoker_result["state"]
|
||||
state.set("messages", tool_messages)
|
||||
|
||||
# 4. Check if any LLM message's tool call name matches an exit condition
|
||||
if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
|
||||
counter += 1
|
||||
break
|
||||
|
||||
# 5. Fetch the combined messages and send them back to the LLM
|
||||
messages = state.get("messages")
|
||||
counter += 1
|
||||
break
|
||||
|
||||
# 3. Call the ToolInvoker
|
||||
# We only send the messages from the LLM to the tool invoker
|
||||
tool_invoker_result = Pipeline._run_component(
|
||||
component_name="tool_invoker",
|
||||
component={"instance": self._tool_invoker},
|
||||
inputs={"messages": llm_messages, "state": state, "streaming_callback": streaming_callback},
|
||||
component_visits=component_visits,
|
||||
parent_span=span,
|
||||
)
|
||||
tool_messages = tool_invoker_result["tool_messages"]
|
||||
state = tool_invoker_result["state"]
|
||||
state.set("messages", tool_messages)
|
||||
|
||||
# 4. Check if any LLM message's tool call name matches an exit condition
|
||||
if self.exit_conditions != ["text"] and self._check_exit_conditions(llm_messages, tool_messages):
|
||||
counter += 1
|
||||
break
|
||||
|
||||
# 5. Fetch the combined messages and send them back to the LLM
|
||||
messages = state.get("messages")
|
||||
counter += 1
|
||||
except AgentBreakpointException as e:
|
||||
# Add current agent state to the exception and re-raise
|
||||
e.results = {**state.data}
|
||||
raise
|
||||
|
||||
if counter >= self.max_agent_steps:
|
||||
logger.warning(
|
||||
@ -312,7 +675,13 @@ class Agent:
|
||||
return result
|
||||
|
||||
async def run_async(
|
||||
self, messages: List[ChatMessage], streaming_callback: Optional[StreamingCallbackT] = None, **kwargs: Any
|
||||
self,
|
||||
messages: List[ChatMessage],
|
||||
streaming_callback: Optional[StreamingCallbackT] = None,
|
||||
breakpoints: Optional[Set[Tuple[str, Optional[int]]]] = None,
|
||||
resume_state: Optional[Dict[str, Any]] = None,
|
||||
debug_path: Optional[Union[str, Path]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Asynchronously process messages and execute tools until the exit condition is met.
|
||||
@ -325,6 +694,10 @@ class Agent:
|
||||
:param streaming_callback: An asynchronous callback that will be invoked when a response
|
||||
is streamed from the LLM. The same callback can be configured to emit tool results
|
||||
when a tool is called.
|
||||
:param breakpoints: Set of tuples of component names and visit counts at which the agent should break execution.
|
||||
Valid components are "chat_generator", "tool_invoker", or "tool_invoker:tool_name" for specific tools.
|
||||
:param resume_state: A dictionary containing the state of a previously saved agent execution.
|
||||
:param debug_path: Path to the directory where the agent state should be saved when breakpoints are hit.
|
||||
:param kwargs: Additional data to pass to the State schema used by the Agent.
|
||||
The keys must match the schema defined in the Agent's `state_schema`.
|
||||
:returns:
|
||||
@ -332,6 +705,7 @@ class Agent:
|
||||
- "messages": List of all messages exchanged during the agent's run.
|
||||
- "last_message": The last message exchanged during the agent's run.
|
||||
- Any additional keys defined in the `state_schema`.
|
||||
:raises AgentBreakpointException: When a breakpoint is triggered.
|
||||
"""
|
||||
if not self._is_warmed_up and hasattr(self.chat_generator, "warm_up"):
|
||||
raise RuntimeError("The component Agent wasn't warmed up. Run 'warm_up()' before calling 'run_async()'.")
|
||||
|
Loading…
x
Reference in New Issue
Block a user