Load and Save state in AgentChat (#4436)

1. Convert dataclass types to Pydantic BaseModel.
2. Add save_state and load_state for ChatAgent.
3. Add state types for AgentChat.

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
fef06fdc8a
commit
777f2abbd7
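In short, the change lets callers snapshot an agent or team into a JSON-serializable mapping and restore it later. A minimal sketch of the new API, assuming an OpenAI model client; the file path is illustrative:

```python
import asyncio
import json

from autogen_agentchat.agents import AssistantAgent
from autogen_ext.models import OpenAIChatCompletionClient


async def main() -> None:
    agent = AssistantAgent(
        name="assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06"),
    )
    # ... run the agent on some task ...

    # save_state() returns a JSON-serializable mapping (a dumped Pydantic model).
    state = await agent.save_state()
    with open("agent_state.json", "w") as f:  # illustrative path
        json.dump(state, f)

    # Later, restore the saved state into a fresh instance.
    with open("agent_state.json") as f:
        restored = json.load(f)
    new_agent = AssistantAgent(
        name="assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06"),
    )
    await new_agent.load_state(restored)


asyncio.run(main())
```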
@@ -127,7 +127,7 @@ async def main() -> None:
         lambda: Coder(
             model_client=client,
             system_messages=[
-                SystemMessage("""You are a general-purpose AI assistant and can handle many questions -- but you don't have access to a web browser. However, the user you are talking to does have a browser, and you can see the screen. Provide short direct instructions to them to take you where you need to go to answer the initial question posed to you.
+                SystemMessage(content="""You are a general-purpose AI assistant and can handle many questions -- but you don't have access to a web browser. However, the user you are talking to does have a browser, and you can see the screen. Provide short direct instructions to them to take you where you need to go to answer the initial question posed to you.

 Once the user has taken the final necessary action to complete the task, and you have fully addressed the initial request, reply with the word TERMINATE.""",
                 )
@@ -2,7 +2,7 @@ import asyncio
 import json
 import logging
 import warnings
-from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Sequence
+from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Mapping, Sequence

 from autogen_core import CancellationToken, FunctionCall
 from autogen_core.components.models import (
@@ -29,6 +29,7 @@ from ..messages import (
     ToolCallMessage,
     ToolCallResultMessage,
 )
+from ..state import AssistantAgentState
 from ._base_chat_agent import BaseChatAgent

 event_logger = logging.getLogger(EVENT_LOGGER_NAME)
@@ -49,6 +50,12 @@ class Handoff(HandoffBase):
 class AssistantAgent(BaseChatAgent):
     """An agent that provides assistance with tool use.

+    ```{note}
+    The assistant agent is not thread-safe or coroutine-safe.
+    It should not be shared between multiple tasks or coroutines, and it should
+    not call its methods concurrently.
+    ```
+
     Args:
         name (str): The name of the agent.
         model_client (ChatCompletionClient): The model client to use for inference.
@@ -224,6 +231,7 @@ class AssistantAgent(BaseChatAgent):
                 f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
             )
+        self._model_context: List[LLMMessage] = []
         self._is_running = False

     @property
     def produced_message_types(self) -> List[type[ChatMessage]]:
@@ -327,3 +335,13 @@ class AssistantAgent(BaseChatAgent):
     async def on_reset(self, cancellation_token: CancellationToken) -> None:
         """Reset the assistant agent to its initialization state."""
         self._model_context.clear()
+
+    async def save_state(self) -> Mapping[str, Any]:
+        """Save the current state of the assistant agent."""
+        return AssistantAgentState(llm_messages=self._model_context.copy()).model_dump()
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        """Load the state of the assistant agent."""
+        assistant_agent_state = AssistantAgentState.model_validate(state)
+        self._model_context.clear()
+        self._model_context.extend(assistant_agent_state.llm_messages)
@@ -1,10 +1,11 @@
 from abc import ABC, abstractmethod
-from typing import AsyncGenerator, List, Sequence
+from typing import Any, AsyncGenerator, List, Mapping, Sequence

 from autogen_core import CancellationToken

 from ..base import ChatAgent, Response, TaskResult
 from ..messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
+from ..state import BaseState


 class BaseChatAgent(ChatAgent, ABC):
@@ -117,3 +118,11 @@ class BaseChatAgent(ChatAgent, ABC):
     async def on_reset(self, cancellation_token: CancellationToken) -> None:
         """Resets the agent to its initialization state."""
         ...
+
+    async def save_state(self) -> Mapping[str, Any]:
+        """Export state. Default implementation for stateless agents."""
+        return BaseState().model_dump()
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        """Restore agent from saved state. Default implementation for stateless agents."""
+        BaseState.model_validate(state)
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import AsyncGenerator, List, Protocol, Sequence, runtime_checkable
+from typing import Any, AsyncGenerator, List, Mapping, Protocol, Sequence, runtime_checkable

 from autogen_core import CancellationToken

@@ -54,3 +54,11 @@ class ChatAgent(TaskRunner, Protocol):
     async def on_reset(self, cancellation_token: CancellationToken) -> None:
         """Resets the agent to its initialization state."""
         ...
+
+    async def save_state(self) -> Mapping[str, Any]:
+        """Save agent state for later restoration."""
+        ...
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        """Restore agent from saved state."""
+        ...
@@ -1,4 +1,4 @@
-from typing import Protocol
+from typing import Any, Mapping, Protocol

 from ._task import TaskRunner

@@ -7,3 +7,11 @@ class Team(TaskRunner, Protocol):
     async def reset(self) -> None:
         """Reset the team and all its participants to its initial state."""
         ...
+
+    async def save_state(self) -> Mapping[str, Any]:
+        """Save the current state of the team."""
+        ...
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        """Load the state of the team."""
+        ...
@@ -1,8 +1,9 @@
-from typing import List
+from typing import List, Literal

 from autogen_core import FunctionCall, Image
 from autogen_core.components.models import FunctionExecutionResult, RequestUsage
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field
+from typing_extensions import Annotated


 class BaseMessage(BaseModel):
@@ -23,6 +24,8 @@ class TextMessage(BaseMessage):
     content: str
     """The content of the message."""

+    type: Literal["TextMessage"] = "TextMessage"
+

 class MultiModalMessage(BaseMessage):
     """A multimodal message."""
@@ -30,6 +33,8 @@ class MultiModalMessage(BaseMessage):
     content: List[str | Image]
     """The content of the message."""

+    type: Literal["MultiModalMessage"] = "MultiModalMessage"
+

 class StopMessage(BaseMessage):
     """A message requesting stop of a conversation."""
@@ -37,6 +42,8 @@ class StopMessage(BaseMessage):
     content: str
     """The content for the stop message."""

+    type: Literal["StopMessage"] = "StopMessage"
+

 class HandoffMessage(BaseMessage):
     """A message requesting handoff of a conversation to another agent."""
@@ -47,6 +54,8 @@ class HandoffMessage(BaseMessage):
     content: str
     """The handoff message to the target agent."""

+    type: Literal["HandoffMessage"] = "HandoffMessage"
+

 class ToolCallMessage(BaseMessage):
     """A message signaling the use of tools."""
@@ -54,6 +63,8 @@ class ToolCallMessage(BaseMessage):
     content: List[FunctionCall]
     """The tool calls."""

+    type: Literal["ToolCallMessage"] = "ToolCallMessage"
+

 class ToolCallResultMessage(BaseMessage):
     """A message signaling the results of tool calls."""
@@ -61,12 +72,17 @@ class ToolCallResultMessage(BaseMessage):
     content: List[FunctionExecutionResult]
     """The tool call results."""

+    type: Literal["ToolCallResultMessage"] = "ToolCallResultMessage"
+

-ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage
+ChatMessage = Annotated[TextMessage | MultiModalMessage | StopMessage | HandoffMessage, Field(discriminator="type")]
 """Messages for agent-to-agent communication."""


-AgentMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallMessage | ToolCallResultMessage
+AgentMessage = Annotated[
+    TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallMessage | ToolCallResultMessage,
+    Field(discriminator="type"),
+]
 """All message types."""
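The new `type` discriminator is what lets Pydantic rebuild the right concrete message class from a plain dict, which makes the dumped state round-trippable. A small illustrative check, not part of the commit (it assumes `models_usage` defaults to `None`, as the dumps later in this page suggest):

```python
from pydantic import TypeAdapter

from autogen_agentchat.messages import ChatMessage, TextMessage

# A dumped message is just a dict carrying a "type" tag...
payload = {"source": "user", "content": "hello", "type": "TextMessage"}

# ...and the discriminated union picks the concrete class on validation.
message = TypeAdapter(ChatMessage).validate_python(payload)
assert isinstance(message, TextMessage)
```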
@@ -0,0 +1,25 @@
+"""State management for agents, teams and termination conditions."""
+
+from ._states import (
+    AssistantAgentState,
+    BaseGroupChatManagerState,
+    BaseState,
+    ChatAgentContainerState,
+    MagenticOneOrchestratorState,
+    RoundRobinManagerState,
+    SelectorManagerState,
+    SwarmManagerState,
+    TeamState,
+)
+
+__all__ = [
+    "BaseState",
+    "AssistantAgentState",
+    "BaseGroupChatManagerState",
+    "ChatAgentContainerState",
+    "RoundRobinManagerState",
+    "SelectorManagerState",
+    "SwarmManagerState",
+    "MagenticOneOrchestratorState",
+    "TeamState",
+]
@@ -0,0 +1,81 @@
+from typing import Any, List, Mapping, Optional
+
+from autogen_core.components.models import (
+    LLMMessage,
+)
+from pydantic import BaseModel, Field
+
+from ..messages import (
+    AgentMessage,
+    ChatMessage,
+)
+
+
+class BaseState(BaseModel):
+    """Base class for all saveable state"""
+
+    type: str = Field(default="BaseState")
+    version: str = Field(default="1.0.0")
+
+
+class AssistantAgentState(BaseState):
+    """State for an assistant agent."""
+
+    llm_messages: List[LLMMessage] = Field(default_factory=list)
+    type: str = Field(default="AssistantAgentState")
+
+
+class TeamState(BaseState):
+    """State for a team of agents."""
+
+    agent_states: Mapping[str, Any] = Field(default_factory=dict)
+    team_id: str = Field(default="")
+    type: str = Field(default="TeamState")
+
+
+class BaseGroupChatManagerState(BaseState):
+    """Base state for all group chat managers."""
+
+    message_thread: List[AgentMessage] = Field(default_factory=list)
+    current_turn: int = Field(default=0)
+    type: str = Field(default="BaseGroupChatManagerState")
+
+
+class ChatAgentContainerState(BaseState):
+    """State for a container of chat agents."""
+
+    agent_state: Mapping[str, Any] = Field(default_factory=dict)
+    message_buffer: List[ChatMessage] = Field(default_factory=list)
+    type: str = Field(default="ChatAgentContainerState")
+
+
+class RoundRobinManagerState(BaseGroupChatManagerState):
+    """State for :class:`~autogen_agentchat.teams.RoundRobinGroupChat` manager."""
+
+    next_speaker_index: int = Field(default=0)
+    type: str = Field(default="RoundRobinManagerState")
+
+
+class SelectorManagerState(BaseGroupChatManagerState):
+    """State for :class:`~autogen_agentchat.teams.SelectorGroupChat` manager."""
+
+    previous_speaker: Optional[str] = Field(default=None)
+    type: str = Field(default="SelectorManagerState")
+
+
+class SwarmManagerState(BaseGroupChatManagerState):
+    """State for :class:`~autogen_agentchat.teams.Swarm` manager."""
+
+    current_speaker: str = Field(default="")
+    type: str = Field(default="SwarmManagerState")
+
+
+class MagenticOneOrchestratorState(BaseGroupChatManagerState):
+    """State for :class:`~autogen_agentchat.teams.MagenticOneGroupChat` orchestrator."""
+
+    task: str = Field(default="")
+    facts: str = Field(default="")
+    plan: str = Field(default="")
+    n_rounds: int = Field(default=0)
+    n_stalls: int = Field(default=0)
+    type: str = Field(default="MagenticOneOrchestratorState")
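Custom agents can piggyback on the same pattern: subclass `BaseState` with a distinct `type` tag and dump/validate it in `save_state`/`load_state`. A hypothetical sketch mirroring AssistantAgent's implementation; the `CounterAgentState` class and its fields are illustrative, not part of this commit:

```python
from typing import Any, List, Mapping

from pydantic import Field

from autogen_agentchat.state import BaseState


class CounterAgentState(BaseState):
    """Hypothetical state for an agent that counts the messages it has seen."""

    message_count: int = Field(default=0)
    notes: List[str] = Field(default_factory=list)
    type: str = Field(default="CounterAgentState")


class CounterStateMixin:
    """Illustrative save/load methods following the AssistantAgent pattern."""

    _count: int = 0
    _notes: List[str] = []

    async def save_state(self) -> Mapping[str, Any]:
        # Dump a Pydantic model so the result is JSON-serializable.
        return CounterAgentState(message_count=self._count, notes=list(self._notes)).model_dump()

    async def load_state(self, state: Mapping[str, Any]) -> None:
        # Validate back into the typed model before mutating the agent.
        restored = CounterAgentState.model_validate(state)
        self._count = restored.message_count
        self._notes = list(restored.notes)
```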
@@ -2,7 +2,7 @@ import asyncio
 import logging
 import uuid
 from abc import ABC, abstractmethod
-from typing import AsyncGenerator, Callable, List
+from typing import Any, AsyncGenerator, Callable, List, Mapping

 from autogen_core import (
     AgentId,
@@ -20,6 +20,7 @@ from autogen_core.application import SingleThreadedAgentRuntime
 from ... import EVENT_LOGGER_NAME
 from ...base import ChatAgent, TaskResult, Team, TerminationCondition
 from ...messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
+from ...state import TeamState
 from ._chat_agent_container import ChatAgentContainer
 from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
 from ._sequential_routed_agent import SequentialRoutedAgent
@@ -493,3 +494,38 @@ class BaseGroupChat(Team, ABC):

             # Indicate that the team is no longer running.
             self._is_running = False
+
+    async def save_state(self) -> Mapping[str, Any]:
+        """Save the state of the group chat team."""
+        if not self._initialized:
+            raise RuntimeError("The group chat has not been initialized. It must be run before it can be saved.")
+
+        if self._is_running:
+            raise RuntimeError("The team cannot be saved while it is running.")
+        self._is_running = True
+
+        try:
+            # Save the state of the runtime. This will save the state of the participants and the group chat manager.
+            agent_states = await self._runtime.save_state()
+            return TeamState(agent_states=agent_states, team_id=self._team_id).model_dump()
+        finally:
+            # Indicate that the team is no longer running.
+            self._is_running = False
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        """Load the state of the group chat team."""
+        if not self._initialized:
+            await self._init(self._runtime)
+
+        if self._is_running:
+            raise RuntimeError("The team cannot be loaded while it is running.")
+        self._is_running = True
+
+        try:
+            # Load the state of the runtime. This will load the state of the participants and the group chat manager.
+            team_state = TeamState.model_validate(state)
+            self._team_id = team_state.team_id
+            await self._runtime.load_state(team_state.agent_states)
+        finally:
+            # Indicate that the team is no longer running.
+            self._is_running = False
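Note the guards: a team must have been initialized (run at least once) before it can be saved, and neither save nor load may overlap a running task. A sketch of typical usage under those constraints; the model name and task are illustrative:

```python
import asyncio
import json

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_ext.models import OpenAIChatCompletionClient


async def main() -> None:
    agent = AssistantAgent("writer", model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06"))
    team = RoundRobinGroupChat([agent], termination_condition=MaxMessageTermination(max_messages=2))

    # The team must have been run (initialized) before save_state is allowed.
    await team.run(task="Write a haiku about autumn.")
    state = await team.save_state()

    # TeamState.model_dump() yields plain JSON-serializable data.
    with open("team_state.json", "w") as f:  # illustrative path
        json.dump(state, f)


asyncio.run(main())
```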
@@ -1,9 +1,10 @@
-from typing import Any, List
+from typing import Any, List, Mapping

 from autogen_core import DefaultTopicId, MessageContext, event, rpc

 from ...base import ChatAgent, Response
 from ...messages import ChatMessage
+from ...state import ChatAgentContainerState
 from ._events import GroupChatAgentResponse, GroupChatMessage, GroupChatRequestPublish, GroupChatReset, GroupChatStart
 from ._sequential_routed_agent import SequentialRoutedAgent

@@ -75,3 +76,13 @@ class ChatAgentContainer(SequentialRoutedAgent):

     async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
         raise ValueError(f"Unhandled message in agent container: {type(message)}")
+
+    async def save_state(self) -> Mapping[str, Any]:
+        agent_state = await self._agent.save_state()
+        state = ChatAgentContainerState(agent_state=agent_state, message_buffer=list(self._message_buffer))
+        return state.model_dump()
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        container_state = ChatAgentContainerState.model_validate(state)
+        self._message_buffer = list(container_state.message_buffer)
+        await self._agent.load_state(container_state.agent_state)
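Putting the pieces together: a dumped team state nests one `ChatAgentContainerState` per participant plus the manager's state, keyed by runtime agent id. An abbreviated sketch of the shape (values are illustrative; compare the real dump in the notebook further down this page):

```python
# Abbreviated shape of BaseGroupChat.save_state() output (illustrative values).
team_state = {
    "type": "TeamState",
    "version": "1.0.0",
    "team_id": "<uuid>",
    "agent_states": {
        "group_chat_manager/<team_id>": {
            "type": "RoundRobinManagerState",
            "message_thread": ["..."],
            "current_turn": 0,
            "next_speaker_index": 0,
        },
        "assistant_agent/<team_id>": {
            "type": "ChatAgentContainerState",
            "agent_state": {"type": "AssistantAgentState", "llm_messages": ["..."]},
            "message_buffer": [],
        },
    },
}
```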
@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Mapping

 from autogen_core import AgentId, CancellationToken, DefaultTopicId, Image, MessageContext, event, rpc
 from autogen_core.components.models import (
@@ -13,6 +13,7 @@ from autogen_core.components.models import (
 from .... import TRACE_LOGGER_NAME
 from ....base import Response, TerminationCondition
 from ....messages import AgentMessage, ChatMessage, MultiModalMessage, StopMessage, TextMessage
+from ....state import MagenticOneOrchestratorState
 from .._base_group_chat_manager import BaseGroupChatManager
 from .._events import (
     GroupChatAgentResponse,
@@ -178,6 +179,28 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
     async def validate_group_state(self, message: ChatMessage | None) -> None:
         pass

+    async def save_state(self) -> Mapping[str, Any]:
+        state = MagenticOneOrchestratorState(
+            message_thread=list(self._message_thread),
+            current_turn=self._current_turn,
+            task=self._task,
+            facts=self._facts,
+            plan=self._plan,
+            n_rounds=self._n_rounds,
+            n_stalls=self._n_stalls,
+        )
+        return state.model_dump()
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        orchestrator_state = MagenticOneOrchestratorState.model_validate(state)
+        self._message_thread = orchestrator_state.message_thread
+        self._current_turn = orchestrator_state.current_turn
+        self._task = orchestrator_state.task
+        self._facts = orchestrator_state.facts
+        self._plan = orchestrator_state.plan
+        self._n_rounds = orchestrator_state.n_rounds
+        self._n_stalls = orchestrator_state.n_stalls
+
     async def select_speaker(self, thread: List[AgentMessage]) -> str:
         """Not used in this orchestrator, we select next speaker in _orchestrate_step."""
         return ""
@@ -1,7 +1,8 @@
-from typing import Callable, List
+from typing import Any, Callable, List, Mapping

 from ...base import ChatAgent, TerminationCondition
 from ...messages import AgentMessage, ChatMessage
+from ...state import RoundRobinManagerState
 from ._base_group_chat import BaseGroupChat
 from ._base_group_chat_manager import BaseGroupChatManager

@@ -38,6 +39,20 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
             await self._termination_condition.reset()
         self._next_speaker_index = 0

+    async def save_state(self) -> Mapping[str, Any]:
+        state = RoundRobinManagerState(
+            message_thread=list(self._message_thread),
+            current_turn=self._current_turn,
+            next_speaker_index=self._next_speaker_index,
+        )
+        return state.model_dump()
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        round_robin_state = RoundRobinManagerState.model_validate(state)
+        self._message_thread = list(round_robin_state.message_thread)
+        self._current_turn = round_robin_state.current_turn
+        self._next_speaker_index = round_robin_state.next_speaker_index
+
     async def select_speaker(self, thread: List[AgentMessage]) -> str:
         """Select a speaker from the participants in a round-robin fashion."""
         current_speaker_index = self._next_speaker_index
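Because `next_speaker_index` and the message thread are both restored, a round-robin team reloaded from a snapshot resumes its rotation where the original left off. A brief sketch under the assumption that `state` came from an earlier `save_state()` call and that participant names match the saved team:

```python
from typing import Any, Mapping

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_ext.models import OpenAIChatCompletionClient


async def resume(state: Mapping[str, Any]) -> None:
    # Participant names should match those in the saved team so state lines up.
    agent1 = AssistantAgent("agent1", model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06"))
    agent2 = AssistantAgent("agent2", model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06"))
    team = RoundRobinGroupChat(participants=[agent1, agent2])
    await team.load_state(state)
    # A follow-up run continues from the restored thread and speaker order.
    await team.run(task="Continue the previous discussion.")
```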
@@ -1,10 +1,10 @@
 import logging
 import re
-from typing import Callable, Dict, List, Sequence
+from typing import Any, Callable, Dict, List, Mapping, Sequence

 from autogen_core.components.models import ChatCompletionClient, SystemMessage

-from ... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
+from ... import TRACE_LOGGER_NAME
 from ...base import ChatAgent, TerminationCondition
 from ...messages import (
     AgentMessage,
@@ -16,11 +16,11 @@ from ...messages import (
     ToolCallMessage,
     ToolCallResultMessage,
 )
+from ...state import SelectorManagerState
 from ._base_group_chat import BaseGroupChat
 from ._base_group_chat_manager import BaseGroupChatManager

 trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
-event_logger = logging.getLogger(EVENT_LOGGER_NAME)


 class SelectorGroupChatManager(BaseGroupChatManager):
@@ -64,6 +64,20 @@ class SelectorGroupChatManager(BaseGroupChatManager):
             await self._termination_condition.reset()
         self._previous_speaker = None

+    async def save_state(self) -> Mapping[str, Any]:
+        state = SelectorManagerState(
+            message_thread=list(self._message_thread),
+            current_turn=self._current_turn,
+            previous_speaker=self._previous_speaker,
+        )
+        return state.model_dump()
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        selector_state = SelectorManagerState.model_validate(state)
+        self._message_thread = list(selector_state.message_thread)
+        self._current_turn = selector_state.current_turn
+        self._previous_speaker = selector_state.previous_speaker
+
     async def select_speaker(self, thread: List[AgentMessage]) -> str:
         """Selects the next speaker in a group chat using a ChatCompletion client,
         with the selector function as override if it returns a speaker name.
@@ -121,7 +135,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
         select_speaker_prompt = self._selector_prompt.format(
             roles=roles, participants=str(participants), history=history
         )
-        select_speaker_messages = [SystemMessage(select_speaker_prompt)]
+        select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
         response = await self._model_client.create(messages=select_speaker_messages)
         assert isinstance(response.content, str)
         mentions = self._mentioned_agents(response.content, self._participant_topic_types)
@@ -1,14 +1,11 @@
 import logging
-from typing import Callable, List
+from typing import Any, Callable, List, Mapping

-from ... import EVENT_LOGGER_NAME
 from ...base import ChatAgent, TerminationCondition
 from ...messages import AgentMessage, ChatMessage, HandoffMessage
+from ...state import SwarmManagerState
 from ._base_group_chat import BaseGroupChat
 from ._base_group_chat_manager import BaseGroupChatManager

-event_logger = logging.getLogger(EVENT_LOGGER_NAME)
-

 class SwarmGroupChatManager(BaseGroupChatManager):
     """A group chat manager that selects the next speaker based on handoff message only."""
@@ -77,6 +74,20 @@ class SwarmGroupChatManager(BaseGroupChatManager):
                 return self._current_speaker
         return self._current_speaker

+    async def save_state(self) -> Mapping[str, Any]:
+        state = SwarmManagerState(
+            message_thread=list(self._message_thread),
+            current_turn=self._current_turn,
+            current_speaker=self._current_speaker,
+        )
+        return state.model_dump()
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        swarm_state = SwarmManagerState.model_validate(state)
+        self._message_thread = list(swarm_state.message_thread)
+        self._current_turn = swarm_state.current_turn
+        self._current_speaker = swarm_state.current_speaker
+

 class Swarm(BaseGroupChat):
     """A group chat team that selects the next speaker based on handoff message only.
@@ -112,12 +112,12 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     ]
     mock = _MockChatCompletion(chat_completions)
     monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
-    tool_use_agent = AssistantAgent(
+    agent = AssistantAgent(
         "tool_use_agent",
         model_client=OpenAIChatCompletionClient(model=model, api_key=""),
         tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
     )
-    result = await tool_use_agent.run(task="task")
+    result = await agent.run(task="task")
     assert len(result.messages) == 4
     assert isinstance(result.messages[0], TextMessage)
     assert result.messages[0].models_usage is None
@@ -135,13 +135,24 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     # Test streaming.
     mock._curr_index = 0  # pyright: ignore
     index = 0
-    async for message in tool_use_agent.run_stream(task="task"):
+    async for message in agent.run_stream(task="task"):
         if isinstance(message, TaskResult):
             assert message == result
         else:
             assert message == result.messages[index]
             index += 1

+    # Test state saving and loading.
+    state = await agent.save_state()
+    agent2 = AssistantAgent(
+        "tool_use_agent",
+        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
+        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
+    )
+    await agent2.load_state(state)
+    state2 = await agent2.save_state()
+    assert state == state2
+

 @pytest.mark.asyncio
 async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
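The `state == state2` round-trip invariant used here is also a convenient check for custom agents: load a saved state into a fresh instance and assert the re-dumped state is identical. A hedged sketch of that pattern outside the test suite (model name is illustrative):

```python
from autogen_agentchat.agents import AssistantAgent
from autogen_ext.models import OpenAIChatCompletionClient


async def check_round_trip(agent: AssistantAgent) -> None:
    # Saved state loaded into a fresh agent should dump back identically.
    state = await agent.save_state()
    fresh = AssistantAgent(
        "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06", api_key=""),
    )
    await fresh.load_state(state)
    assert await fresh.save_state() == state
```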
@@ -28,12 +28,15 @@ from autogen_agentchat.teams import (
     SelectorGroupChat,
     Swarm,
 )
+from autogen_agentchat.teams._group_chat._round_robin_group_chat import RoundRobinGroupChatManager
+from autogen_agentchat.teams._group_chat._selector_group_chat import SelectorGroupChatManager
+from autogen_agentchat.teams._group_chat._swarm_group_chat import SwarmGroupChatManager
 from autogen_agentchat.ui import Console
-from autogen_core import CancellationToken, FunctionCall
+from autogen_core import AgentId, CancellationToken, FunctionCall
 from autogen_core.components.code_executor import LocalCommandLineCodeExecutor
 from autogen_core.components.models import FunctionExecutionResult
 from autogen_core.components.tools import FunctionTool
-from autogen_ext.models import OpenAIChatCompletionClient
+from autogen_ext.models import OpenAIChatCompletionClient, ReplayChatCompletionClient
 from openai.resources.chat.completions import AsyncCompletions
 from openai.types.chat.chat_completion import ChatCompletion, Choice
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -217,6 +220,38 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
     assert result.messages[1:] == result_2.messages[1:]


+@pytest.mark.asyncio
+async def test_round_robin_group_chat_state() -> None:
+    model_client = ReplayChatCompletionClient(
+        ["No facts", "No plan", "print('Hello, world!')", "TERMINATE"],
+    )
+    agent1 = AssistantAgent("agent1", model_client=model_client)
+    agent2 = AssistantAgent("agent2", model_client=model_client)
+    termination = TextMentionTermination("TERMINATE")
+    team1 = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination)
+    await team1.run(task="Write a program that prints 'Hello, world!'")
+    state = await team1.save_state()
+
+    agent3 = AssistantAgent("agent1", model_client=model_client)
+    agent4 = AssistantAgent("agent2", model_client=model_client)
+    team2 = RoundRobinGroupChat(participants=[agent3, agent4], termination_condition=termination)
+    await team2.load_state(state)
+    state2 = await team2.save_state()
+    assert state == state2
+    assert agent3._model_context == agent1._model_context  # pyright: ignore
+    assert agent4._model_context == agent2._model_context  # pyright: ignore
+    manager_1 = await team1._runtime.try_get_underlying_agent_instance(  # pyright: ignore
+        AgentId("group_chat_manager", team1._team_id),  # pyright: ignore
+        RoundRobinGroupChatManager,  # pyright: ignore
+    )  # pyright: ignore
+    manager_2 = await team2._runtime.try_get_underlying_agent_instance(  # pyright: ignore
+        AgentId("group_chat_manager", team2._team_id),  # pyright: ignore
+        RoundRobinGroupChatManager,  # pyright: ignore
+    )  # pyright: ignore
+    assert manager_1._current_turn == manager_2._current_turn  # pyright: ignore
+    assert manager_1._message_thread == manager_2._message_thread  # pyright: ignore
+
+
 @pytest.mark.asyncio
 async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     model = "gpt-4o-2024-05-13"
@@ -528,6 +563,42 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
     assert result2 == result


+@pytest.mark.asyncio
+async def test_selector_group_chat_state() -> None:
+    model_client = ReplayChatCompletionClient(
+        ["agent1", "No facts", "agent2", "No plan", "agent1", "print('Hello, world!')", "agent2", "TERMINATE"],
+    )
+    agent1 = AssistantAgent("agent1", model_client=model_client)
+    agent2 = AssistantAgent("agent2", model_client=model_client)
+    termination = TextMentionTermination("TERMINATE")
+    team1 = SelectorGroupChat(
+        participants=[agent1, agent2], termination_condition=termination, model_client=model_client
+    )
+    await team1.run(task="Write a program that prints 'Hello, world!'")
+    state = await team1.save_state()
+
+    agent3 = AssistantAgent("agent1", model_client=model_client)
+    agent4 = AssistantAgent("agent2", model_client=model_client)
+    team2 = SelectorGroupChat(
+        participants=[agent3, agent4], termination_condition=termination, model_client=model_client
+    )
+    await team2.load_state(state)
+    state2 = await team2.save_state()
+    assert state == state2
+    assert agent3._model_context == agent1._model_context  # pyright: ignore
+    assert agent4._model_context == agent2._model_context  # pyright: ignore
+    manager_1 = await team1._runtime.try_get_underlying_agent_instance(  # pyright: ignore
+        AgentId("group_chat_manager", team1._team_id),  # pyright: ignore
+        SelectorGroupChatManager,  # pyright: ignore
+    )  # pyright: ignore
+    manager_2 = await team2._runtime.try_get_underlying_agent_instance(  # pyright: ignore
+        AgentId("group_chat_manager", team2._team_id),  # pyright: ignore
+        SelectorGroupChatManager,  # pyright: ignore
+    )  # pyright: ignore
+    assert manager_1._message_thread == manager_2._message_thread  # pyright: ignore
+    assert manager_1._previous_speaker == manager_2._previous_speaker  # pyright: ignore
+
+
 @pytest.mark.asyncio
 async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None:
     model = "gpt-4o-2024-05-13"
@@ -768,6 +839,26 @@ async def test_swarm_handoff() -> None:
             assert message == result.messages[index]
             index += 1

+    # Test save and load.
+    state = await team.save_state()
+    first_agent2 = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
+    second_agent2 = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
+    third_agent2 = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
+    team2 = Swarm([second_agent2, first_agent2, third_agent2], termination_condition=termination)
+    await team2.load_state(state)
+    state2 = await team2.save_state()
+    assert state == state2
+    manager_1 = await team._runtime.try_get_underlying_agent_instance(  # pyright: ignore
+        AgentId("group_chat_manager", team._team_id),  # pyright: ignore
+        SwarmGroupChatManager,  # pyright: ignore
+    )  # pyright: ignore
+    manager_2 = await team2._runtime.try_get_underlying_agent_instance(  # pyright: ignore
+        AgentId("group_chat_manager", team2._team_id),  # pyright: ignore
+        SwarmGroupChatManager,  # pyright: ignore
+    )  # pyright: ignore
+    assert manager_1._message_thread == manager_2._message_thread  # pyright: ignore
+    assert manager_1._current_speaker == manager_2._current_speaker  # pyright: ignore
+

 @pytest.mark.asyncio
 async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -16,7 +16,8 @@ from autogen_agentchat.messages import (
 from autogen_agentchat.teams import (
     MagenticOneGroupChat,
 )
-from autogen_core import CancellationToken
+from autogen_agentchat.teams._group_chat._magentic_one._magentic_one_orchestrator import MagenticOneOrchestrator
+from autogen_core import AgentId, CancellationToken
 from autogen_ext.models import ReplayChatCompletionClient
 from utils import FileLogHandler

@@ -121,6 +122,27 @@ async def test_magentic_one_group_chat_basic() -> None:
     assert result.messages[4].content == "print('Hello, world!')"
     assert result.stop_reason is not None and result.stop_reason == "Because"

+    # Test save and load.
+    state = await team.save_state()
+    team2 = MagenticOneGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client)
+    await team2.load_state(state)
+    state2 = await team2.save_state()
+    assert state == state2
+    manager_1 = await team._runtime.try_get_underlying_agent_instance(  # pyright: ignore
+        AgentId("group_chat_manager", team._team_id),  # pyright: ignore
+        MagenticOneOrchestrator,  # pyright: ignore
+    )  # pyright: ignore
+    manager_2 = await team2._runtime.try_get_underlying_agent_instance(  # pyright: ignore
+        AgentId("group_chat_manager", team2._team_id),  # pyright: ignore
+        MagenticOneOrchestrator,  # pyright: ignore
+    )  # pyright: ignore
+    assert manager_1._message_thread == manager_2._message_thread  # pyright: ignore
+    assert manager_1._task == manager_2._task  # pyright: ignore
+    assert manager_1._facts == manager_2._facts  # pyright: ignore
+    assert manager_1._plan == manager_2._plan  # pyright: ignore
+    assert manager_1._n_rounds == manager_2._n_rounds  # pyright: ignore
+    assert manager_1._n_stalls == manager_2._n_stalls  # pyright: ignore
+

 @pytest.mark.asyncio
 async def test_magentic_one_group_chat_with_stalls() -> None:
@@ -48,6 +48,12 @@ A dynamic team that uses handoffs to pass tasks between agents.
 How to build custom agents.
 :::

+:::{grid-item-card} {fas}`users;pst-color-primary` State Management
+:link: ./state.html
+
+How to manage state in agents and teams.
+:::
+
 ::::

 ```{toctree}
@@ -61,4 +67,5 @@ selector-group-chat
 swarm
 termination
 custom-agents
+state
 ```
@@ -0,0 +1,299 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Managing State \n",
    "\n",
    "So far, we have discussed how to build components in a multi-agent application - agents, teams, termination conditions. In many cases, it is useful to save the state of these components to disk and load them back later. This is particularly useful in a web application where stateless endpoints respond to requests and need to load the state of the application from persistent storage.\n",
    "\n",
    "In this notebook, we will discuss how to save and load the state of agents, teams, and termination conditions. \n",
    " \n",
    "\n",
    "## Saving and Loading Agents\n",
    "\n",
    "We can get the state of an agent by calling the {py:meth}`~autogen_agentchat.agents.AssistantAgent.save_state` method on \n",
    "an {py:class}`~autogen_agentchat.agents.AssistantAgent`. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "In Tanganyika's depths so wide and deep,  \n",
      "Ancient secrets in still waters sleep,  \n",
      "Ripples tell tales that time longs to keep.  \n"
     ]
    }
   ],
   "source": [
    "from autogen_agentchat.agents import AssistantAgent\n",
    "from autogen_agentchat.conditions import MaxMessageTermination\n",
    "from autogen_agentchat.messages import TextMessage\n",
    "from autogen_agentchat.teams import RoundRobinGroupChat\n",
    "from autogen_agentchat.ui import Console\n",
    "from autogen_core import CancellationToken\n",
    "from autogen_ext.models import OpenAIChatCompletionClient\n",
    "\n",
    "assistant_agent = AssistantAgent(\n",
    "    name=\"assistant_agent\",\n",
    "    system_message=\"You are a helpful assistant\",\n",
    "    model_client=OpenAIChatCompletionClient(\n",
    "        model=\"gpt-4o-2024-08-06\",\n",
    "        # api_key=\"YOUR_API_KEY\",\n",
    "    ),\n",
    ")\n",
    "\n",
    "# Use asyncio.run(...) when running in a script.\n",
    "response = await assistant_agent.on_messages(\n",
    "    [TextMessage(content=\"Write a 3 line poem on lake tangayika\", source=\"user\")], CancellationToken()\n",
    ")\n",
    "print(response.chat_message.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'type': 'AssistantAgentState', 'version': '1.0.0', 'llm_messages': [{'content': 'Write a 3 line poem on lake tangayika', 'source': 'user', 'type': 'UserMessage'}, {'content': \"In Tanganyika's depths so wide and deep,  \\nAncient secrets in still waters sleep,  \\nRipples tell tales that time longs to keep.  \", 'source': 'assistant_agent', 'type': 'AssistantMessage'}]}\n"
     ]
    }
   ],
   "source": [
    "agent_state = await assistant_agent.save_state()\n",
    "print(agent_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The last line of the poem I wrote was:  \n",
      "\"Ripples tell tales that time longs to keep.\"\n"
     ]
    }
   ],
   "source": [
    "new_assistant_agent = AssistantAgent(\n",
    "    name=\"assistant_agent\",\n",
    "    system_message=\"You are a helpful assistant\",\n",
    "    model_client=OpenAIChatCompletionClient(\n",
    "        model=\"gpt-4o-2024-08-06\",\n",
    "    ),\n",
    ")\n",
    "await new_assistant_agent.load_state(agent_state)\n",
    "\n",
    "# Use asyncio.run(...) when running in a script.\n",
    "response = await new_assistant_agent.on_messages(\n",
    "    [TextMessage(content=\"What was the last line of the previous poem you wrote\", source=\"user\")], CancellationToken()\n",
    ")\n",
    "print(response.chat_message.content)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{note}\n",
    "For {py:class}`~autogen_agentchat.agents.AssistantAgent`, its state consists of the model_context.\n",
    "If you write your own custom agent, consider overriding the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.save_state` and {py:meth}`~autogen_agentchat.agents.BaseChatAgent.load_state` methods to customize the behavior. The default implementations save and load an empty state.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving and Loading Teams \n",
    "\n",
    "We can get the state of a team by calling the `save_state` method on the team and load it back by calling the `load_state` method on the team. \n",
    "\n",
    "When we call `save_state` on a team, it saves the state of all the agents in the team.\n",
    "\n",
    "We will begin by creating a simple {py:class}`~autogen_agentchat.teams.RoundRobinGroupChat` team with a single agent and ask it to write a poem. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---------- user ----------\n",
      "Write a beautiful poem 3-line about lake tangayika\n",
      "---------- assistant_agent ----------\n",
      "In Tanganyika's depths, where light gently weaves,  \n",
      "Silver reflections dance on ancient water's face,  \n",
      "Whispered stories of time in the rippling leaves.  \n",
      "[Prompt tokens: 29, Completion tokens: 36]\n",
      "---------- Summary ----------\n",
      "Number of messages: 2\n",
      "Finish reason: Maximum number of messages 2 reached, current message count: 2\n",
      "Total prompt tokens: 29\n",
      "Total completion tokens: 36\n",
      "Duration: 1.16 seconds\n"
     ]
    }
   ],
   "source": [
    "# Define a team.\n",
    "assistant_agent = AssistantAgent(\n",
    "    name=\"assistant_agent\",\n",
    "    system_message=\"You are a helpful assistant\",\n",
    "    model_client=OpenAIChatCompletionClient(\n",
    "        model=\"gpt-4o-2024-08-06\",\n",
    "    ),\n",
    ")\n",
    "agent_team = RoundRobinGroupChat([assistant_agent], termination_condition=MaxMessageTermination(max_messages=2))\n",
    "\n",
    "# Run the team and stream messages to the console.\n",
    "stream = agent_team.run_stream(task=\"Write a beautiful poem 3-line about lake tangayika\")\n",
    "\n",
    "# Use asyncio.run(...) when running in a script.\n",
    "await Console(stream)\n",
    "\n",
    "# Save the state of the agent team.\n",
    "team_state = await agent_team.save_state()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we reset the team (simulating instantiation of the team), and ask the question `What was the last line of the poem you wrote?`, we see that the team is unable to accomplish this as there is no reference to the previous run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---------- user ----------\n",
      "What was the last line of the poem you wrote?\n",
      "---------- assistant_agent ----------\n",
      "I don't write poems on my own, but I can help create one with you or try to recall a specific poem if you have one in mind. Let me know what you'd like to do!\n",
      "[Prompt tokens: 28, Completion tokens: 39]\n",
      "---------- Summary ----------\n",
      "Number of messages: 2\n",
      "Finish reason: Maximum number of messages 2 reached, current message count: 2\n",
      "Total prompt tokens: 28\n",
      "Total completion tokens: 39\n",
      "Duration: 0.95 seconds\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TaskResult(messages=[TextMessage(source='user', models_usage=None, type='TextMessage', content='What was the last line of the poem you wrote?'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=28, completion_tokens=39), type='TextMessage', content=\"I don't write poems on my own, but I can help create one with you or try to recall a specific poem if you have one in mind. Let me know what you'd like to do!\")], stop_reason='Maximum number of messages 2 reached, current message count: 2')"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await agent_team.reset()\n",
    "stream = agent_team.run_stream(task=\"What was the last line of the poem you wrote?\")\n",
    "await Console(stream)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we load the state of the team and ask the same question. We see that the team is able to accurately return the last line of the poem it wrote.\n",
    "\n",
    "Note: You can serialize the state of the team to a file and load it back later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'type': 'TeamState', 'version': '1.0.0', 'agent_states': {'group_chat_manager/c80054be-efb2-4bc7-ba0d-900962092c44': {'type': 'RoundRobinManagerState', 'version': '1.0.0', 'message_thread': [{'source': 'user', 'models_usage': None, 'type': 'TextMessage', 'content': 'Write a beautiful poem 3-line about lake tangayika'}, {'source': 'assistant_agent', 'models_usage': {'prompt_tokens': 29, 'completion_tokens': 36}, 'type': 'TextMessage', 'content': \"In Tanganyika's depths, where light gently weaves,  \\nSilver reflections dance on ancient water's face,  \\nWhispered stories of time in the rippling leaves.  \"}], 'current_turn': 0, 'next_speaker_index': 0}, 'collect_output_messages/c80054be-efb2-4bc7-ba0d-900962092c44': {}, 'assistant_agent/c80054be-efb2-4bc7-ba0d-900962092c44': {'type': 'ChatAgentContainerState', 'version': '1.0.0', 'agent_state': {'type': 'AssistantAgentState', 'version': '1.0.0', 'llm_messages': [{'content': 'Write a beautiful poem 3-line about lake tangayika', 'source': 'user', 'type': 'UserMessage'}, {'content': \"In Tanganyika's depths, where light gently weaves,  \\nSilver reflections dance on ancient water's face,  \\nWhispered stories of time in the rippling leaves.  \", 'source': 'assistant_agent', 'type': 'AssistantMessage'}]}, 'message_buffer': []}}, 'team_id': 'c80054be-efb2-4bc7-ba0d-900962092c44'}\n",
      "---------- user ----------\n",
      "What was the last line of the poem you wrote?\n",
      "---------- assistant_agent ----------\n",
      "The last line of the poem I wrote was:  \n",
      "\"Whispered stories of time in the rippling leaves.\"\n",
      "[Prompt tokens: 88, Completion tokens: 24]\n",
      "---------- Summary ----------\n",
      "Number of messages: 2\n",
      "Finish reason: Maximum number of messages 2 reached, current message count: 2\n",
      "Total prompt tokens: 88\n",
      "Total completion tokens: 24\n",
      "Duration: 0.79 seconds\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TaskResult(messages=[TextMessage(source='user', models_usage=None, type='TextMessage', content='What was the last line of the poem you wrote?'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=88, completion_tokens=24), type='TextMessage', content='The last line of the poem I wrote was:  \\n\"Whispered stories of time in the rippling leaves.\"')], stop_reason='Maximum number of messages 2 reached, current message count: 2')"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(team_state)\n",
    "\n",
    "# Load team state.\n",
    "await agent_team.load_state(team_state)\n",
    "stream = agent_team.run_stream(task=\"What was the last line of the poem you wrote?\")\n",
    "await Console(stream)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@ -1,283 +1,283 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# User Approval for Tool Execution using Intervention Handler\n",
|
||||
"\n",
|
||||
"This cookbook shows how to intercept the tool execution using\n",
|
||||
"an intervention hanlder, and prompt the user for permission to execute the tool."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from dataclasses import dataclass\n",
|
||||
"from typing import Any, List\n",
|
||||
"\n",
|
||||
"from autogen_core import AgentId, AgentType, FunctionCall, MessageContext, RoutedAgent, message_handler\n",
|
||||
"from autogen_core.application import SingleThreadedAgentRuntime\n",
|
||||
"from autogen_core.base.intervention import DefaultInterventionHandler, DropMessage\n",
|
||||
"from autogen_core.components.models import (\n",
|
||||
" ChatCompletionClient,\n",
|
||||
" LLMMessage,\n",
|
||||
" SystemMessage,\n",
|
||||
" UserMessage,\n",
|
||||
")\n",
|
||||
"from autogen_core.components.tools import PythonCodeExecutionTool, ToolSchema\n",
|
||||
"from autogen_core.tool_agent import ToolAgent, ToolException, tool_agent_caller_loop\n",
|
||||
"from autogen_ext.code_executors import DockerCommandLineCodeExecutor\n",
|
||||
"from autogen_ext.models import OpenAIChatCompletionClient"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's define a simple message type that carries a string content."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"@dataclass\n",
|
||||
"class Message:\n",
|
||||
" content: str"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's create a simple tool use agent that is capable of using tools through a\n",
|
||||
"{py:class}`~autogen_core.components.tool_agent.ToolAgent`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class ToolUseAgent(RoutedAgent):\n",
|
||||
" \"\"\"An agent that uses tools to perform tasks. It executes the tools\n",
|
||||
" by itself by sending the tool execution task to a ToolAgent.\"\"\"\n",
|
||||
"\n",
|
||||
" def __init__(\n",
|
||||
" self,\n",
|
||||
" description: str,\n",
|
||||
" system_messages: List[SystemMessage],\n",
|
||||
" model_client: ChatCompletionClient,\n",
|
||||
" tool_schema: List[ToolSchema],\n",
|
||||
" tool_agent_type: AgentType,\n",
|
||||
" ) -> None:\n",
|
||||
" super().__init__(description)\n",
|
||||
" self._model_client = model_client\n",
|
||||
" self._system_messages = system_messages\n",
|
||||
" self._tool_schema = tool_schema\n",
|
||||
" self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key)\n",
|
||||
"\n",
|
||||
" @message_handler\n",
|
||||
" async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n",
|
||||
" \"\"\"Handle a user message, execute the model and tools, and returns the response.\"\"\"\n",
|
||||
" session: List[LLMMessage] = [UserMessage(content=message.content, source=\"User\")]\n",
|
||||
" # Use the tool agent to execute the tools, and get the output messages.\n",
|
||||
" output_messages = await tool_agent_caller_loop(\n",
|
||||
" self,\n",
|
||||
" tool_agent_id=self._tool_agent_id,\n",
|
||||
" model_client=self._model_client,\n",
|
||||
" input_messages=session,\n",
|
||||
" tool_schema=self._tool_schema,\n",
|
||||
" cancellation_token=ctx.cancellation_token,\n",
|
||||
" )\n",
|
||||
" # Extract the final response from the output messages.\n",
|
||||
" final_response = output_messages[-1].content\n",
|
||||
" assert isinstance(final_response, str)\n",
|
||||
" return Message(content=final_response)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The tool use agent sends tool call requests to the tool agent to execute tools,\n",
|
||||
"so we can intercept the messages sent by the tool use agent to the tool agent\n",
|
||||
"to prompt the user for permission to execute the tool.\n",
|
"\n",
"Let's create an intervention handler that intercepts the messages and prompts the\n",
"user before allowing the tool execution."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"class ToolInterventionHandler(DefaultInterventionHandler):\n",
"    async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:\n",
"        if isinstance(message, FunctionCall):\n",
"            # Request user prompt for tool execution.\n",
"            user_input = input(\n",
"                f\"Function call: {message.name}\\nArguments: {message.arguments}\\nDo you want to execute the tool? (y/n): \"\n",
"            )\n",
"            if user_input.strip().lower() != \"y\":\n",
"                raise ToolException(content=\"User denied tool execution.\", call_id=message.id)\n",
"        return message"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can create a runtime with the intervention handler registered."
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"# Create the runtime with the intervention handler.\n",
"runtime = SingleThreadedAgentRuntime(intervention_handlers=[ToolInterventionHandler()])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example, we will use a tool for Python code execution.\n",
"First, we create a Docker-based command-line code executor\n",
"using {py:class}`~autogen_ext.code_executors.DockerCommandLineCodeExecutor`,\n",
"and then use it to instantiate a built-in Python code execution tool\n",
"{py:class}`~autogen_core.components.tools.PythonCodeExecutionTool`\n",
"that runs code in a Docker container."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"# Create the docker executor for the Python code execution tool.\n",
"docker_executor = DockerCommandLineCodeExecutor()\n",
"\n",
"# Create the Python code execution tool.\n",
"python_tool = PythonCodeExecutionTool(executor=docker_executor)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Register the agents with tools and tool schema."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AgentType(type='tool_enabled_agent')"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Register agents.\n",
"tool_agent_type = await ToolAgent.register(\n",
"    runtime,\n",
"    \"tool_executor_agent\",\n",
"    lambda: ToolAgent(\n",
"        description=\"Tool Executor Agent\",\n",
"        tools=[python_tool],\n",
"    ),\n",
")\n",
"await ToolUseAgent.register(\n",
"    runtime,\n",
"    \"tool_enabled_agent\",\n",
"    lambda: ToolUseAgent(\n",
"        description=\"Tool Use Agent\",\n",
"        system_messages=[SystemMessage(\"You are a helpful AI Assistant. Use your tools to solve problems.\")],\n",
"        system_messages=[SystemMessage(content=\"You are a helpful AI Assistant. Use your tools to solve problems.\")],\n",
"        model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n",
"        tool_schema=[python_tool.schema],\n",
"        tool_agent_type=tool_agent_type,\n",
"    ),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the agents by starting the runtime and sending a message to the tool use agent.\n",
"The intervention handler will prompt you for permission to execute the tool."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The output of the code is: **Hello, World!**\n"
]
}
],
"source": [
"# Start the runtime and the docker executor.\n",
"await docker_executor.start()\n",
"runtime.start()\n",
"\n",
"# Send a task to the tool user.\n",
"response = await runtime.send_message(\n",
"    Message(\"Run the following Python code: print('Hello, World!')\"), AgentId(\"tool_enabled_agent\", \"default\")\n",
")\n",
"print(response.content)\n",
"\n",
"# Stop the runtime and the docker executor.\n",
"await runtime.stop()\n",
"await docker_executor.stop()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
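A note on the handler above: raising ToolException with the original call id lets the denial flow back through the tool-agent caller loop as a failed tool execution that the model can react to. A minimal alternative sketch, assuming the intervention-handler contract that returning the DropMessage type drops the message outright (in that case the awaiting sender sees the message as dropped instead of receiving a result):

class SilentDropHandler(DefaultInterventionHandler):
    async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:
        if isinstance(message, FunctionCall):
            # Drop the tool call before it reaches the tool executor agent.
            return DropMessage
        return message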
File diff suppressed because it is too large
@ -158,7 +158,7 @@
"        super().__init__(description=description)\n",
"        self._group_chat_topic_type = group_chat_topic_type\n",
"        self._model_client = model_client\n",
"        self._system_message = SystemMessage(system_message)\n",
"        self._system_message = SystemMessage(content=system_message)\n",
"        self._chat_history: List[LLMMessage] = []\n",
"\n",
"    @message_handler\n",
@ -427,7 +427,7 @@
"Read the above conversation. Then select the next role from {participants} to play. Only return the role.\n",
"\"\"\"\n",
"        system_message = SystemMessage(\n",
"            selector_prompt.format(\n",
"            content=selector_prompt.format(\n",
"                roles=roles,\n",
"                history=history,\n",
"                participants=str(\n",
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -1,315 +1,315 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tools\n",
"\n",
"Tools are code that can be executed by an agent to perform actions. A tool\n",
"can be a simple function such as a calculator, or an API call to a third-party service\n",
"such as stock price lookup or weather forecast.\n",
"In the context of AI agents, tools are designed to be executed by agents in\n",
"response to model-generated function calls.\n",
"\n",
"AutoGen provides the {py:mod}`autogen_core.components.tools` module with a suite of built-in\n",
"tools and utilities for creating and running custom tools."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Built-in Tools\n",
"\n",
"One of the built-in tools is the {py:class}`~autogen_core.components.tools.PythonCodeExecutionTool`,\n",
"which allows agents to execute Python code snippets.\n",
"\n",
"Here is how you create the tool and use it."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello, world!\n",
"\n"
]
}
],
"source": [
"from autogen_core import CancellationToken\n",
"from autogen_core.components.tools import PythonCodeExecutionTool\n",
"from autogen_ext.code_executors import DockerCommandLineCodeExecutor\n",
"\n",
"# Create the tool.\n",
"code_executor = DockerCommandLineCodeExecutor()\n",
"await code_executor.start()\n",
"code_execution_tool = PythonCodeExecutionTool(code_executor)\n",
"cancellation_token = CancellationToken()\n",
"\n",
"# Use the tool directly without an agent.\n",
"code = \"print('Hello, world!')\"\n",
"result = await code_execution_tool.run_json({\"code\": code}, cancellation_token)\n",
"print(code_execution_tool.return_value_as_string(result))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The {py:class}`~autogen_ext.code_executors.DockerCommandLineCodeExecutor`\n",
"class is a built-in code executor that runs Python code snippets in a command-line\n",
"environment inside a Docker container.\n",
"The {py:class}`~autogen_core.components.tools.PythonCodeExecutionTool` class wraps the code executor\n",
"and provides a simple interface to execute Python code snippets.\n",
"\n",
"Other built-in tools will be added in the future."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Custom Function Tools\n",
"\n",
"A tool can also be a simple Python function that performs a specific action.\n",
"To create a custom function tool, you just need to create a Python function\n",
"and use the {py:class}`~autogen_core.components.tools.FunctionTool` class to wrap it.\n",
"\n",
"The {py:class}`~autogen_core.components.tools.FunctionTool` class uses descriptions and type annotations\n",
"to inform the LLM when and how to use a given function. The description provides context\n",
"about the function’s purpose and intended use cases, while type annotations inform the LLM about\n",
"the expected parameters and return type.\n",
"\n",
"For example, a simple tool to obtain the stock price of a company might look like this:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"80.44429939059668\n"
]
}
],
"source": [
"import random\n",
"\n",
"from autogen_core import CancellationToken\n",
"from autogen_core.components.tools import FunctionTool\n",
"from typing_extensions import Annotated\n",
"\n",
"\n",
"async def get_stock_price(ticker: str, date: Annotated[str, \"Date in YYYY/MM/DD\"]) -> float:\n",
"    # Returns a random stock price for demonstration purposes.\n",
"    return random.uniform(10, 200)\n",
"\n",
"\n",
"# Create a function tool.\n",
"stock_price_tool = FunctionTool(get_stock_price, description=\"Get the stock price.\")\n",
"\n",
"# Run the tool.\n",
"cancellation_token = CancellationToken()\n",
"result = await stock_price_tool.run_json({\"ticker\": \"AAPL\", \"date\": \"2021/01/01\"}, cancellation_token)\n",
"\n",
"# Print the result.\n",
"print(stock_price_tool.return_value_as_string(result))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool-Equipped Agent\n",
"\n",
"To use tools with an agent, you can use {py:class}`~autogen_core.components.tool_agent.ToolAgent`\n",
"in a composition pattern.\n",
"Here is an example tool-use agent that uses {py:class}`~autogen_core.components.tool_agent.ToolAgent`\n",
"as an inner agent for executing tools."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from typing import List\n",
"\n",
"from autogen_core import AgentId, AgentInstantiationContext, MessageContext, RoutedAgent, message_handler\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core.components.models import (\n",
"    ChatCompletionClient,\n",
"    LLMMessage,\n",
"    SystemMessage,\n",
"    UserMessage,\n",
")\n",
"from autogen_core.components.tools import FunctionTool, Tool, ToolSchema\n",
"from autogen_core.tool_agent import ToolAgent, tool_agent_caller_loop\n",
"from autogen_ext.models import OpenAIChatCompletionClient\n",
"\n",
"\n",
"@dataclass\n",
"class Message:\n",
"    content: str\n",
"\n",
"\n",
"class ToolUseAgent(RoutedAgent):\n",
"    def __init__(self, model_client: ChatCompletionClient, tool_schema: List[ToolSchema], tool_agent_type: str) -> None:\n",
"        super().__init__(\"An agent with tools\")\n",
"        self._system_messages: List[LLMMessage] = [SystemMessage(\"You are a helpful AI assistant.\")]\n",
"        self._system_messages: List[LLMMessage] = [SystemMessage(content=\"You are a helpful AI assistant.\")]\n",
"        self._model_client = model_client\n",
"        self._tool_schema = tool_schema\n",
"        self._tool_agent_id = AgentId(tool_agent_type, self.id.key)\n",
"\n",
"    @message_handler\n",
"    async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n",
"        # Create a session of messages.\n",
"        session: List[LLMMessage] = [UserMessage(content=message.content, source=\"user\")]\n",
"        # Run the caller loop to handle tool calls.\n",
"        messages = await tool_agent_caller_loop(\n",
"            self,\n",
"            tool_agent_id=self._tool_agent_id,\n",
"            model_client=self._model_client,\n",
"            input_messages=session,\n",
"            tool_schema=self._tool_schema,\n",
"            cancellation_token=ctx.cancellation_token,\n",
"        )\n",
"        # Return the final response.\n",
"        assert isinstance(messages[-1].content, str)\n",
"        return Message(content=messages[-1].content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `ToolUseAgent` class uses a convenience function {py:meth}`~autogen_core.components.tool_agent.tool_agent_caller_loop`\n",
"to handle the interaction between the model and the tool agent.\n",
"The core idea can be described using a simple control flow graph:\n",
"\n",
"\n",
"\n",
"The `ToolUseAgent`'s `handle_user_message` handler handles messages from the user,\n",
"and determines whether the model has generated a tool call.\n",
"If the model has generated tool calls, then the handler sends a function call\n",
"message to the {py:class}`~autogen_core.components.tool_agent.ToolAgent` agent\n",
"to execute the tools,\n",
"and then queries the model again with the results of the tool calls.\n",
"This process continues until the model stops generating tool calls,\n",
"at which point the final response is returned to the user.\n",
"\n",
"By having the tool execution logic in a separate agent,\n",
"we expose the model-tool interactions to the agent runtime as messages, so the tool executions\n",
"can be observed externally and intercepted if necessary.\n",
"\n",
"To run the agent, we need to create a runtime and register the agent."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AgentType(type='tool_use_agent')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create a runtime.\n",
"runtime = SingleThreadedAgentRuntime()\n",
"# Create the tools.\n",
"tools: List[Tool] = [FunctionTool(get_stock_price, description=\"Get the stock price.\")]\n",
"# Register the agents.\n",
"await ToolAgent.register(runtime, \"tool_executor_agent\", lambda: ToolAgent(\"tool executor agent\", tools))\n",
"await ToolUseAgent.register(\n",
"    runtime,\n",
"    \"tool_use_agent\",\n",
"    lambda: ToolUseAgent(\n",
"        OpenAIChatCompletionClient(model=\"gpt-4o-mini\"), [tool.schema for tool in tools], \"tool_executor_agent\"\n",
"    ),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This example uses the {py:class}`~autogen_ext.models.OpenAIChatCompletionClient`;\n",
"for Azure OpenAI and other clients, see [Model Clients](./model-clients.ipynb).\n",
"Let's test the agent with a question about stock price."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The stock price of NVDA (NVIDIA Corporation) on June 1, 2024, was approximately $179.46.\n"
]
}
],
"source": [
"# Start processing messages.\n",
"runtime.start()\n",
"# Send a direct message to the tool use agent.\n",
"tool_use_agent = AgentId(\"tool_use_agent\", \"default\")\n",
"response = await runtime.send_message(Message(\"What is the stock price of NVDA on 2024/06/01?\"), tool_use_agent)\n",
"print(response.content)\n",
"# Stop processing messages.\n",
"await runtime.stop()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "autogen_core",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@ -36,7 +36,7 @@ Read the following conversation. Then select the next role from {participants} t

Read the above conversation. Then select the next role from {participants} to play. Only return the role.
"""
    select_speaker_messages = [SystemMessage(select_speaker_prompt)]
    select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
    response = await client.create(messages=select_speaker_messages)
    assert isinstance(response.content, str)
    mentions = await mentioned_agents(response.content, agents)
@ -92,7 +92,7 @@ def convert_messages_to_llm_messages(
            converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
            if converted_message_2 is not None:
                result.append(converted_message_2)
        case FunctionExecutionResultMessage(_):
        case FunctionExecutionResultMessage(content=_):
            converted_message_3 = convert_tool_call_response_message(message, handle_unrepresentable)
            if converted_message_3 is not None:
                result.append(converted_message_3)
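The pattern change above is a direct consequence of the dataclass-to-pydantic migration in this commit: dataclasses generate __match_args__ automatically, so the positional class pattern FunctionExecutionResultMessage(_) used to work, while pydantic BaseModel subclasses do not define __match_args__ by default, and positional sub-patterns raise TypeError at match time. A minimal sketch of the distinction:

from pydantic import BaseModel

class Msg(BaseModel):
    content: str

m = Msg(content="hi")
match m:
    # case Msg(_):  # would raise TypeError: Msg() accepts 0 positional sub-patterns
    case Msg(content=value):  # keyword patterns match on attributes and always work
        print(value)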
@ -31,7 +31,7 @@ class BaseGroupChatAgent(RoutedAgent):
        super().__init__(description=description)
        self._group_chat_topic_type = group_chat_topic_type
        self._model_client = model_client
        self._system_message = SystemMessage(system_message)
        self._system_message = SystemMessage(content=system_message)
        self._chat_history: List[LLMMessage] = []
        self._ui_config = ui_config
        self.console = Console()
@ -126,7 +126,7 @@ Read the following conversation. Then select the next role from {participants} t

Read the above conversation. Then select the next role from {participants} to play. if you think it's enough talking (for example they have talked for {self._max_rounds} rounds), return 'FINISH'.
"""
        system_message = SystemMessage(selector_prompt)
        system_message = SystemMessage(content=selector_prompt)
        completion = await self._model_client.create([system_message], cancellation_token=ctx.cancellation_token)

        assert isinstance(
@ -160,12 +160,14 @@ class SchedulingAssistantAgent(RoutedAgent):
        self._name = name
        self._model_client = model_client
        self._system_messages = [
            SystemMessage(f"""
            SystemMessage(
                content=f"""
I am a helpful AI assistant that helps schedule meetings.
If there are missing parameters, I will ask for them.

Today's date is {datetime.datetime.now().strftime("%Y-%m-%d")}
""")
"""
            )
        ]

    @message_handler
@ -116,10 +116,12 @@ class ClosureAgent(BaseAgent, ClosureContext):
        return await self._closure(self, message, ctx)

    async def save_state(self) -> Mapping[str, Any]:
        raise ValueError("save_state not implemented for ClosureAgent")
        """Closure agents do not have state. So this method always returns an empty dictionary."""
        return {}

    async def load_state(self, state: Mapping[str, Any]) -> None:
        raise ValueError("load_state not implemented for ClosureAgent")
        """Closure agents do not have state. So this method does nothing."""
        pass

    @classmethod
    async def register_closure(
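With this change, saving state across a whole runtime no longer fails just because a ClosureAgent is registered; the closure agent simply contributes nothing. A hypothetical usage sketch (closure_agent stands in for a registered instance):

state = await closure_agent.save_state()   # returns {} (closure agents hold no state)
await closure_agent.load_state(state)      # no-op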
@ -1,42 +1,49 @@
from dataclasses import dataclass
from typing import List, Literal, Optional, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated

from ... import FunctionCall, Image


@dataclass
class SystemMessage:
class SystemMessage(BaseModel):
    content: str
    type: Literal["SystemMessage"] = "SystemMessage"


@dataclass
class UserMessage:
class UserMessage(BaseModel):
    content: Union[str, List[Union[str, Image]]]

    # Name of the agent that sent this message
    source: str

    type: Literal["UserMessage"] = "UserMessage"


@dataclass
class AssistantMessage:
class AssistantMessage(BaseModel):
    content: Union[str, List[FunctionCall]]

    # Name of the agent that sent this message
    source: str

    type: Literal["AssistantMessage"] = "AssistantMessage"


@dataclass
class FunctionExecutionResult:
class FunctionExecutionResult(BaseModel):
    content: str
    call_id: str


@dataclass
class FunctionExecutionResultMessage:
class FunctionExecutionResultMessage(BaseModel):
    content: List[FunctionExecutionResult]

    type: Literal["FunctionExecutionResultMessage"] = "FunctionExecutionResultMessage"


LLMMessage = Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage]

LLMMessage = Annotated[
    Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage], Field(discriminator="type")
]


@dataclass
@ -54,16 +61,14 @@ class TopLogprob:
    bytes: Optional[List[int]] = None


@dataclass
class ChatCompletionTokenLogprob:
class ChatCompletionTokenLogprob(BaseModel):
    token: str
    logprob: float
    top_logprobs: Optional[List[TopLogprob] | None] = None
    bytes: Optional[List[int]] = None


@dataclass
class CreateResult:
class CreateResult(BaseModel):
    finish_reason: FinishReasons
    content: Union[str, List[FunctionCall]]
    usage: RequestUsage
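This conversion is what makes the new state plumbing serializable end to end: each message now dumps to a plain dict tagged with its type, and the discriminated union picks the right class on the way back in. A minimal round-trip sketch, assuming pydantic v2's TypeAdapter and the message classes defined above:

from typing import List

from pydantic import TypeAdapter

adapter = TypeAdapter(List[LLMMessage])
context: List[LLMMessage] = [
    SystemMessage(content="You are a helpful AI assistant."),
    UserMessage(content="What is the weather today?", source="user"),
]
payload = adapter.dump_python(context)       # list of dicts, each carrying its "type" tag
restored = adapter.validate_python(payload)  # the tag selects the concrete message class
assert restored == context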
@ -36,9 +36,11 @@ class FileSurfer(BaseChatAgent):
    DEFAULT_DESCRIPTION = "An agent that can handle local files."

    DEFAULT_SYSTEM_MESSAGES = [
        SystemMessage("""
        SystemMessage(
            content="""
        You are a helpful AI Assistant.
        When given a user query, use available functions to help the user with their request."""),
        When given a user query, use available functions to help the user with their request."""
        ),
    ]

    def __init__(
@ -78,7 +80,7 @@ class FileSurfer(BaseChatAgent):

        except BaseException:
            content = f"File surfing error:\n\n{traceback.format_exc()}"
            self._chat_history.append(AssistantMessage(content, source=self.name))
            self._chat_history.append(AssistantMessage(content=content, source=self.name))
            return Response(chat_message=TextMessage(content=content, source=self.name))

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
@ -190,7 +190,7 @@ class MultimodalWebSurfer(BaseChatAgent):

        except BaseException:
            content = f"Web surfing error:\n\n{traceback.format_exc()}"
            self._chat_history.append(AssistantMessage(content, source=self.name))
            self._chat_history.append(AssistantMessage(content=content, source=self.name))
            return Response(chat_message=TextMessage(content=content, source=self.name))

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
@ -712,7 +712,7 @@ class MultimodalWebSurfer(BaseChatAgent):
        for line in page_markdown.splitlines():
            message = UserMessage(
                # content=[
                prompt + buffer + line,
                content=prompt + buffer + line,
                # ag_image,
                # ],
                source=self.name,
@ -33,7 +33,7 @@ class LLMAgent(RoutedAgent):

    @property
    def _fixed_message_history_type(self) -> List[SystemMessage]:
        return [SystemMessage(msg.content) for msg in self._chat_history]
        return [SystemMessage(content=msg.content) for msg in self._chat_history]


@default_subscription
@ -21,7 +21,8 @@ class Coder(BaseWorker):
    DEFAULT_DESCRIPTION = "A helpful and general-purpose AI assistant that has strong language skills, Python skills, and Linux command line skills."

    DEFAULT_SYSTEM_MESSAGES = [
        SystemMessage("""You are a helpful AI assistant.
        SystemMessage(
            content="""You are a helpful AI assistant.
Solve tasks using your coding and language skills.
In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.
    1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.
@ -31,7 +32,8 @@ When using code, you must indicate the script type in the code block. The user c
If you want the user to save the code in a file before executing it, put # filename: <filename> inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.
If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.
When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
Reply "TERMINATE" in the end when everything is done.""")
Reply "TERMINATE" in the end when everything is done."""
        )
    ]

    def __init__(
@ -23,9 +23,11 @@ class FileSurfer(BaseWorker):
    DEFAULT_DESCRIPTION = "An agent that can handle local files."

    DEFAULT_SYSTEM_MESSAGES = [
        SystemMessage("""
        SystemMessage(
            content="""
        You are a helpful AI Assistant.
        When given a user query, use available functions to help the user with their request."""),
        When given a user query, use available functions to help the user with their request."""
        ),
    ]

    def __init__(
@ -817,7 +817,7 @@ When deciding between tools, consider if the request can be best addressed by:
        for line in re.split(r"([\r\n]+)", page_markdown):
            message = UserMessage(
                # content=[
                prompt + buffer + line,
                content=prompt + buffer + line,
                # ag_image,
                # ],
                source=self.metadata["type"],
@ -47,7 +47,7 @@ class LedgerOrchestrator(BaseOrchestrator):
    It uses a ledger (implemented as a JSON generated by the LLM) to keep track of task progress and select the next agent that should speak."""

    DEFAULT_SYSTEM_MESSAGES = [
        SystemMessage(ORCHESTRATOR_SYSTEM_MESSAGE),
        SystemMessage(content=ORCHESTRATOR_SYSTEM_MESSAGE),
    ]

    def __init__(