Load and Save state in AgentChat (#4436)

1. convert dataclass types to pydantic basemodel 
2. add save_state and load_state for ChatAgent
3. state types for AgentChat
---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
Victor Dibia 2024-12-04 16:14:41 -08:00 committed by GitHub
parent fef06fdc8a
commit 777f2abbd7
39 changed files with 3684 additions and 2964 deletions
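A minimal sketch of the agent-level API this change introduces (a sketch, not taken from the diff verbatim; it assumes an OpenAI model client configured through the OPENAI_API_KEY environment variable):

import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_ext.models import OpenAIChatCompletionClient


async def main() -> None:
    agent = AssistantAgent(
        name="assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06"),
    )
    await agent.run(task="Write a haiku about Lake Tanganyika.")

    # save_state() returns a plain mapping (a pydantic model dump).
    state = await agent.save_state()

    # A fresh agent with the same configuration resumes the conversation.
    restored = AssistantAgent(
        name="assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06"),
    )
    await restored.load_state(state)


asyncio.run(main())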

View File

@ -127,7 +127,7 @@ async def main() -> None:
lambda: Coder(
model_client=client,
system_messages=[
SystemMessage("""You are a general-purpose AI assistant and can handle many questions -- but you don't have access to a web browser. However, the user you are talking to does have a browser, and you can see the screen. Provide short direct instructions to them to take you where you need to go to answer the initial question posed to you.
SystemMessage(content="""You are a general-purpose AI assistant and can handle many questions -- but you don't have access to a web browser. However, the user you are talking to does have a browser, and you can see the screen. Provide short direct instructions to them to take you where you need to go to answer the initial question posed to you.
Once the user has taken the final necessary action to complete the task, and you have fully addressed the initial request, reply with the word TERMINATE.""",
)

View File

@ -2,7 +2,7 @@ import asyncio
import json
import logging
import warnings
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Sequence
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Mapping, Sequence
from autogen_core import CancellationToken, FunctionCall
from autogen_core.components.models import (
@ -29,6 +29,7 @@ from ..messages import (
ToolCallMessage,
ToolCallResultMessage,
)
from ..state import AssistantAgentState
from ._base_chat_agent import BaseChatAgent
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
@ -49,6 +50,12 @@ class Handoff(HandoffBase):
class AssistantAgent(BaseChatAgent):
"""An agent that provides assistance with tool use.
```{note}
The assistant agent is not thread-safe or coroutine-safe.
It should not be shared between multiple tasks or coroutines, and it should
not call its methods concurrently.
```
Args:
name (str): The name of the agent.
model_client (ChatCompletionClient): The model client to use for inference.
@ -224,6 +231,7 @@ class AssistantAgent(BaseChatAgent):
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
)
self._model_context: List[LLMMessage] = []
self._is_running = False
@property
def produced_message_types(self) -> List[type[ChatMessage]]:
@ -327,3 +335,13 @@ class AssistantAgent(BaseChatAgent):
async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Reset the assistant agent to its initialization state."""
self._model_context.clear()
async def save_state(self) -> Mapping[str, Any]:
"""Save the current state of the assistant agent."""
return AssistantAgentState(llm_messages=self._model_context.copy()).model_dump()
async def load_state(self, state: Mapping[str, Any]) -> None:
"""Load the state of the assistant agent"""
assistant_agent_state = AssistantAgentState.model_validate(state)
self._model_context.clear()
self._model_context.extend(assistant_agent_state.llm_messages)

View File

@ -1,10 +1,11 @@
from abc import ABC, abstractmethod
from typing import AsyncGenerator, List, Sequence
from typing import Any, AsyncGenerator, List, Mapping, Sequence
from autogen_core import CancellationToken
from ..base import ChatAgent, Response, TaskResult
from ..messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
from ..state import BaseState
class BaseChatAgent(ChatAgent, ABC):
@ -117,3 +118,11 @@ class BaseChatAgent(ChatAgent, ABC):
async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Resets the agent to its initialization state."""
...
async def save_state(self) -> Mapping[str, Any]:
"""Export state. Default implementation for stateless agents."""
return BaseState().model_dump()
async def load_state(self, state: Mapping[str, Any]) -> None:
"""Restore agent from saved state. Default implementation for stateless agents."""
BaseState.model_validate(state)
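The defaults above save and load an empty BaseState, which is correct for stateless agents. An agent that actually holds state should override both methods with its own pydantic model, following the same pattern as AssistantAgentState. A sketch (MyAgent and its notes field are hypothetical; the other abstract methods are elided):

from typing import Any, List, Mapping

from pydantic import Field

from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.state import BaseState


class MyAgentState(BaseState):
    notes: List[str] = Field(default_factory=list)
    type: str = Field(default="MyAgentState")


class MyAgent(BaseChatAgent):
    # produced_message_types, on_messages, and on_reset elided for brevity.

    async def save_state(self) -> Mapping[str, Any]:
        # Dump the agent's mutable state into a plain, validatable mapping.
        return MyAgentState(notes=self._notes).model_dump()

    async def load_state(self, state: Mapping[str, Any]) -> None:
        # Validate first, then copy the fields back into the agent.
        self._notes = list(MyAgentState.model_validate(state).notes)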

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import AsyncGenerator, List, Protocol, Sequence, runtime_checkable
from typing import Any, AsyncGenerator, List, Mapping, Protocol, Sequence, runtime_checkable
from autogen_core import CancellationToken
@ -54,3 +54,11 @@ class ChatAgent(TaskRunner, Protocol):
async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Resets the agent to its initialization state."""
...
async def save_state(self) -> Mapping[str, Any]:
"""Save agent state for later restoration"""
...
async def load_state(self, state: Mapping[str, Any]) -> None:
"""Restore agent from saved state"""
...

View File

@ -1,4 +1,4 @@
from typing import Protocol
from typing import Any, Mapping, Protocol
from ._task import TaskRunner
@ -7,3 +7,11 @@ class Team(TaskRunner, Protocol):
async def reset(self) -> None:
"""Reset the team and all its participants to its initial state."""
...
async def save_state(self) -> Mapping[str, Any]:
"""Save the current state of the team."""
...
async def load_state(self, state: Mapping[str, Any]) -> None:
"""Load the state of the team."""
...

View File

@ -1,8 +1,9 @@
from typing import List
from typing import List, Literal
from autogen_core import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult, RequestUsage
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated
class BaseMessage(BaseModel):
@ -23,6 +24,8 @@ class TextMessage(BaseMessage):
content: str
"""The content of the message."""
type: Literal["TextMessage"] = "TextMessage"
class MultiModalMessage(BaseMessage):
"""A multimodal message."""
@ -30,6 +33,8 @@ class MultiModalMessage(BaseMessage):
content: List[str | Image]
"""The content of the message."""
type: Literal["MultiModalMessage"] = "MultiModalMessage"
class StopMessage(BaseMessage):
"""A message requesting stop of a conversation."""
@ -37,6 +42,8 @@ class StopMessage(BaseMessage):
content: str
"""The content for the stop message."""
type: Literal["StopMessage"] = "StopMessage"
class HandoffMessage(BaseMessage):
"""A message requesting handoff of a conversation to another agent."""
@ -47,6 +54,8 @@ class HandoffMessage(BaseMessage):
content: str
"""The handoff message to the target agent."""
type: Literal["HandoffMessage"] = "HandoffMessage"
class ToolCallMessage(BaseMessage):
"""A message signaling the use of tools."""
@ -54,6 +63,8 @@ class ToolCallMessage(BaseMessage):
content: List[FunctionCall]
"""The tool calls."""
type: Literal["ToolCallMessage"] = "ToolCallMessage"
class ToolCallResultMessage(BaseMessage):
"""A message signaling the results of tool calls."""
@ -61,12 +72,17 @@ class ToolCallResultMessage(BaseMessage):
content: List[FunctionExecutionResult]
"""The tool call results."""
type: Literal["ToolCallResultMessage"] = "ToolCallResultMessage"
ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage
ChatMessage = Annotated[TextMessage | MultiModalMessage | StopMessage | HandoffMessage, Field(discriminator="type")]
"""Messages for agent-to-agent communication."""
AgentMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallMessage | ToolCallResultMessage
AgentMessage = Annotated[
TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallMessage | ToolCallResultMessage,
Field(discriminator="type"),
]
"""All message types."""

View File

@ -0,0 +1,25 @@
"""State management for agents, teams and termination conditions."""
from ._states import (
AssistantAgentState,
BaseGroupChatManagerState,
BaseState,
ChatAgentContainerState,
MagenticOneOrchestratorState,
RoundRobinManagerState,
SelectorManagerState,
SwarmManagerState,
TeamState,
)
__all__ = [
"BaseState",
"AssistantAgentState",
"BaseGroupChatManagerState",
"ChatAgentContainerState",
"RoundRobinManagerState",
"SelectorManagerState",
"SwarmManagerState",
"MagenticOneOrchestratorState",
"TeamState",
]

View File

@ -0,0 +1,81 @@
from typing import Any, List, Mapping, Optional
from autogen_core.components.models import (
LLMMessage,
)
from pydantic import BaseModel, Field
from ..messages import (
AgentMessage,
ChatMessage,
)
class BaseState(BaseModel):
"""Base class for all saveable state"""
type: str = Field(default="BaseState")
version: str = Field(default="1.0.0")
class AssistantAgentState(BaseState):
"""State for an assistant agent."""
llm_messages: List[LLMMessage] = Field(default_factory=list)
type: str = Field(default="AssistantAgentState")
class TeamState(BaseState):
"""State for a team of agents."""
agent_states: Mapping[str, Any] = Field(default_factory=dict)
team_id: str = Field(default="")
type: str = Field(default="TeamState")
class BaseGroupChatManagerState(BaseState):
"""Base state for all group chat managers."""
message_thread: List[AgentMessage] = Field(default_factory=list)
current_turn: int = Field(default=0)
type: str = Field(default="BaseGroupChatManagerState")
class ChatAgentContainerState(BaseState):
"""State for a container of chat agents."""
agent_state: Mapping[str, Any] = Field(default_factory=dict)
message_buffer: List[ChatMessage] = Field(default_factory=list)
type: str = Field(default="ChatAgentContainerState")
class RoundRobinManagerState(BaseGroupChatManagerState):
"""State for :class:`~autogen_agentchat.teams.RoundRobinGroupChat` manager."""
next_speaker_index: int = Field(default=0)
type: str = Field(default="RoundRobinManagerState")
class SelectorManagerState(BaseGroupChatManagerState):
"""State for :class:`~autogen_agentchat.teams.SelectorGroupChat` manager."""
previous_speaker: Optional[str] = Field(default=None)
type: str = Field(default="SelectorManagerState")
class SwarmManagerState(BaseGroupChatManagerState):
"""State for :class:`~autogen_agentchat.teams.Swarm` manager."""
current_speaker: str = Field(default="")
type: str = Field(default="SwarmManagerState")
class MagenticOneOrchestratorState(BaseGroupChatManagerState):
"""State for :class:`~autogen_agentchat.teams.MagneticOneGroupChat` orchestrator."""
task: str = Field(default="")
facts: str = Field(default="")
plan: str = Field(default="")
n_rounds: int = Field(default=0)
n_stalls: int = Field(default=0)
type: str = Field(default="MagenticOneOrchestratorState")

View File

@ -2,7 +2,7 @@ import asyncio
import logging
import uuid
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Callable, List
from typing import Any, AsyncGenerator, Callable, List, Mapping
from autogen_core import (
AgentId,
@ -20,6 +20,7 @@ from autogen_core.application import SingleThreadedAgentRuntime
from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
from ...state import TeamState
from ._chat_agent_container import ChatAgentContainer
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
from ._sequential_routed_agent import SequentialRoutedAgent
@ -493,3 +494,38 @@ class BaseGroupChat(Team, ABC):
# Indicate that the team is no longer running.
self._is_running = False
async def save_state(self) -> Mapping[str, Any]:
"""Save the state of the group chat team."""
if not self._initialized:
raise RuntimeError("The group chat has not been initialized. It must be run before it can be saved.")
if self._is_running:
raise RuntimeError("The team cannot be saved while it is running.")
self._is_running = True
try:
# Save the state of the runtime. This will save the state of the participants and the group chat manager.
agent_states = await self._runtime.save_state()
return TeamState(agent_states=agent_states, team_id=self._team_id).model_dump()
finally:
# Indicate that the team is no longer running.
self._is_running = False
async def load_state(self, state: Mapping[str, Any]) -> None:
"""Load the state of the group chat team."""
if not self._initialized:
await self._init(self._runtime)
if self._is_running:
raise RuntimeError("The team cannot be loaded while it is running.")
self._is_running = True
try:
# Load the state of the runtime. This will load the state of the participants and the group chat manager.
team_state = TeamState.model_validate(state)
self._team_id = team_state.team_id
await self._runtime.load_state(team_state.agent_states)
finally:
# Indicate that the team is no longer running.
self._is_running = False
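Two guards are worth noting: save_state() requires that the team has been run at least once, since the runtime is initialized lazily, and neither call is permitted while a run is in flight; load_state(), by contrast, initializes the runtime itself if needed. A sketch of the resulting usage contract (assumes an OpenAI model client configured via OPENAI_API_KEY):

import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_ext.models import OpenAIChatCompletionClient


def make_team() -> RoundRobinGroupChat:
    agent = AssistantAgent(
        name="assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06"),
    )
    return RoundRobinGroupChat([agent], termination_condition=MaxMessageTermination(max_messages=2))


async def main() -> None:
    team = make_team()
    await team.run(task="Write a one-line greeting.")  # Initializes the runtime.
    state = await team.save_state()  # Calling this before any run raises RuntimeError.

    # A fresh, identically configured team can be restored without running first:
    # load_state() initializes the runtime on demand.
    team2 = make_team()
    await team2.load_state(state)


asyncio.run(main())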

View File

@ -1,9 +1,10 @@
from typing import Any, List
from typing import Any, List, Mapping
from autogen_core import DefaultTopicId, MessageContext, event, rpc
from ...base import ChatAgent, Response
from ...messages import ChatMessage
from ...state import ChatAgentContainerState
from ._events import GroupChatAgentResponse, GroupChatMessage, GroupChatRequestPublish, GroupChatReset, GroupChatStart
from ._sequential_routed_agent import SequentialRoutedAgent
@ -75,3 +76,13 @@ class ChatAgentContainer(SequentialRoutedAgent):
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
raise ValueError(f"Unhandled message in agent container: {type(message)}")
async def save_state(self) -> Mapping[str, Any]:
agent_state = await self._agent.save_state()
state = ChatAgentContainerState(agent_state=agent_state, message_buffer=list(self._message_buffer))
return state.model_dump()
async def load_state(self, state: Mapping[str, Any]) -> None:
container_state = ChatAgentContainerState.model_validate(state)
self._message_buffer = list(container_state.message_buffer)
await self._agent.load_state(container_state.agent_state)

View File

@ -1,6 +1,6 @@
import json
import logging
from typing import Any, Dict, List
from typing import Any, Dict, List, Mapping
from autogen_core import AgentId, CancellationToken, DefaultTopicId, Image, MessageContext, event, rpc
from autogen_core.components.models import (
@ -13,6 +13,7 @@ from autogen_core.components.models import (
from .... import TRACE_LOGGER_NAME
from ....base import Response, TerminationCondition
from ....messages import AgentMessage, ChatMessage, MultiModalMessage, StopMessage, TextMessage
from ....state import MagenticOneOrchestratorState
from .._base_group_chat_manager import BaseGroupChatManager
from .._events import (
GroupChatAgentResponse,
@ -178,6 +179,28 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
async def validate_group_state(self, message: ChatMessage | None) -> None:
pass
async def save_state(self) -> Mapping[str, Any]:
state = MagenticOneOrchestratorState(
message_thread=list(self._message_thread),
current_turn=self._current_turn,
task=self._task,
facts=self._facts,
plan=self._plan,
n_rounds=self._n_rounds,
n_stalls=self._n_stalls,
)
return state.model_dump()
async def load_state(self, state: Mapping[str, Any]) -> None:
orchestrator_state = MagenticOneOrchestratorState.model_validate(state)
self._message_thread = orchestrator_state.message_thread
self._current_turn = orchestrator_state.current_turn
self._task = orchestrator_state.task
self._facts = orchestrator_state.facts
self._plan = orchestrator_state.plan
self._n_rounds = orchestrator_state.n_rounds
self._n_stalls = orchestrator_state.n_stalls
async def select_speaker(self, thread: List[AgentMessage]) -> str:
"""Not used in this orchestrator, we select next speaker in _orchestrate_step."""
return ""

View File

@ -1,7 +1,8 @@
from typing import Callable, List
from typing import Any, Callable, List, Mapping
from ...base import ChatAgent, TerminationCondition
from ...messages import AgentMessage, ChatMessage
from ...state import RoundRobinManagerState
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
@ -38,6 +39,20 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
await self._termination_condition.reset()
self._next_speaker_index = 0
async def save_state(self) -> Mapping[str, Any]:
state = RoundRobinManagerState(
message_thread=list(self._message_thread),
current_turn=self._current_turn,
next_speaker_index=self._next_speaker_index,
)
return state.model_dump()
async def load_state(self, state: Mapping[str, Any]) -> None:
round_robin_state = RoundRobinManagerState.model_validate(state)
self._message_thread = list(round_robin_state.message_thread)
self._current_turn = round_robin_state.current_turn
self._next_speaker_index = round_robin_state.next_speaker_index
async def select_speaker(self, thread: List[AgentMessage]) -> str:
"""Select a speaker from the participants in a round-robin fashion."""
current_speaker_index = self._next_speaker_index

View File

@ -1,10 +1,10 @@
import logging
import re
from typing import Callable, Dict, List, Sequence
from typing import Any, Callable, Dict, List, Mapping, Sequence
from autogen_core.components.models import ChatCompletionClient, SystemMessage
from ... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
from ... import TRACE_LOGGER_NAME
from ...base import ChatAgent, TerminationCondition
from ...messages import (
AgentMessage,
@ -16,11 +16,11 @@ from ...messages import (
ToolCallMessage,
ToolCallResultMessage,
)
from ...state import SelectorManagerState
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
class SelectorGroupChatManager(BaseGroupChatManager):
@ -64,6 +64,20 @@ class SelectorGroupChatManager(BaseGroupChatManager):
await self._termination_condition.reset()
self._previous_speaker = None
async def save_state(self) -> Mapping[str, Any]:
state = SelectorManagerState(
message_thread=list(self._message_thread),
current_turn=self._current_turn,
previous_speaker=self._previous_speaker,
)
return state.model_dump()
async def load_state(self, state: Mapping[str, Any]) -> None:
selector_state = SelectorManagerState.model_validate(state)
self._message_thread = list(selector_state.message_thread)
self._current_turn = selector_state.current_turn
self._previous_speaker = selector_state.previous_speaker
async def select_speaker(self, thread: List[AgentMessage]) -> str:
"""Selects the next speaker in a group chat using a ChatCompletion client,
with the selector function as override if it returns a speaker name.
@ -121,7 +135,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
select_speaker_prompt = self._selector_prompt.format(
roles=roles, participants=str(participants), history=history
)
select_speaker_messages = [SystemMessage(select_speaker_prompt)]
select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
response = await self._model_client.create(messages=select_speaker_messages)
assert isinstance(response.content, str)
mentions = self._mentioned_agents(response.content, self._participant_topic_types)

View File

@ -1,14 +1,11 @@
import logging
from typing import Callable, List
from typing import Any, Callable, List, Mapping
from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TerminationCondition
from ...messages import AgentMessage, ChatMessage, HandoffMessage
from ...state import SwarmManagerState
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
class SwarmGroupChatManager(BaseGroupChatManager):
"""A group chat manager that selects the next speaker based on handoff message only."""
@ -77,6 +74,20 @@ class SwarmGroupChatManager(BaseGroupChatManager):
return self._current_speaker
return self._current_speaker
async def save_state(self) -> Mapping[str, Any]:
state = SwarmManagerState(
message_thread=list(self._message_thread),
current_turn=self._current_turn,
current_speaker=self._current_speaker,
)
return state.model_dump()
async def load_state(self, state: Mapping[str, Any]) -> None:
swarm_state = SwarmManagerState.model_validate(state)
self._message_thread = list(swarm_state.message_thread)
self._current_turn = swarm_state.current_turn
self._current_speaker = swarm_state.current_speaker
class Swarm(BaseGroupChat):
"""A group chat team that selects the next speaker based on handoff message only.

View File

@ -112,12 +112,12 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
tool_use_agent = AssistantAgent(
agent = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
)
result = await tool_use_agent.run(task="task")
result = await agent.run(task="task")
assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].models_usage is None
@ -135,13 +135,24 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
# Test streaming.
mock._curr_index = 0 # pyright: ignore
index = 0
async for message in tool_use_agent.run_stream(task="task"):
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message == result
else:
assert message == result.messages[index]
index += 1
# Test state saving and loading.
state = await agent.save_state()
agent2 = AssistantAgent(
"tool_use_agent",
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
)
await agent2.load_state(state)
state2 = await agent2.save_state()
assert state == state2
@pytest.mark.asyncio
async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:

View File

@ -28,12 +28,15 @@ from autogen_agentchat.teams import (
SelectorGroupChat,
Swarm,
)
from autogen_agentchat.teams._group_chat._round_robin_group_chat import RoundRobinGroupChatManager
from autogen_agentchat.teams._group_chat._selector_group_chat import SelectorGroupChatManager
from autogen_agentchat.teams._group_chat._swarm_group_chat import SwarmGroupChatManager
from autogen_agentchat.ui import Console
from autogen_core import CancellationToken, FunctionCall
from autogen_core import AgentId, CancellationToken, FunctionCall
from autogen_core.components.code_executor import LocalCommandLineCodeExecutor
from autogen_core.components.models import FunctionExecutionResult
from autogen_core.components.tools import FunctionTool
from autogen_ext.models import OpenAIChatCompletionClient
from autogen_ext.models import OpenAIChatCompletionClient, ReplayChatCompletionClient
from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@ -217,6 +220,38 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
assert result.messages[1:] == result_2.messages[1:]
@pytest.mark.asyncio
async def test_round_robin_group_chat_state() -> None:
model_client = ReplayChatCompletionClient(
["No facts", "No plan", "print('Hello, world!')", "TERMINATE"],
)
agent1 = AssistantAgent("agent1", model_client=model_client)
agent2 = AssistantAgent("agent2", model_client=model_client)
termination = TextMentionTermination("TERMINATE")
team1 = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination)
await team1.run(task="Write a program that prints 'Hello, world!'")
state = await team1.save_state()
agent3 = AssistantAgent("agent1", model_client=model_client)
agent4 = AssistantAgent("agent2", model_client=model_client)
team2 = RoundRobinGroupChat(participants=[agent3, agent4], termination_condition=termination)
await team2.load_state(state)
state2 = await team2.save_state()
assert state == state2
assert agent3._model_context == agent1._model_context # pyright: ignore
assert agent4._model_context == agent2._model_context # pyright: ignore
manager_1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team1._team_id), # pyright: ignore
RoundRobinGroupChatManager, # pyright: ignore
) # pyright: ignore
manager_2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team2._team_id), # pyright: ignore
RoundRobinGroupChatManager, # pyright: ignore
) # pyright: ignore
assert manager_1._current_turn == manager_2._current_turn # pyright: ignore
assert manager_1._message_thread == manager_2._message_thread # pyright: ignore
@pytest.mark.asyncio
async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
@ -528,6 +563,42 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
assert result2 == result
@pytest.mark.asyncio
async def test_selector_group_chat_state() -> None:
model_client = ReplayChatCompletionClient(
["agent1", "No facts", "agent2", "No plan", "agent1", "print('Hello, world!')", "agent2", "TERMINATE"],
)
agent1 = AssistantAgent("agent1", model_client=model_client)
agent2 = AssistantAgent("agent2", model_client=model_client)
termination = TextMentionTermination("TERMINATE")
team1 = SelectorGroupChat(
participants=[agent1, agent2], termination_condition=termination, model_client=model_client
)
await team1.run(task="Write a program that prints 'Hello, world!'")
state = await team1.save_state()
agent3 = AssistantAgent("agent1", model_client=model_client)
agent4 = AssistantAgent("agent2", model_client=model_client)
team2 = SelectorGroupChat(
participants=[agent3, agent4], termination_condition=termination, model_client=model_client
)
await team2.load_state(state)
state2 = await team2.save_state()
assert state == state2
assert agent3._model_context == agent1._model_context # pyright: ignore
assert agent4._model_context == agent2._model_context # pyright: ignore
manager_1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team1._team_id), # pyright: ignore
SelectorGroupChatManager, # pyright: ignore
) # pyright: ignore
manager_2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team2._team_id), # pyright: ignore
SelectorGroupChatManager, # pyright: ignore
) # pyright: ignore
assert manager_1._message_thread == manager_2._message_thread # pyright: ignore
assert manager_1._previous_speaker == manager_2._previous_speaker # pyright: ignore
@pytest.mark.asyncio
async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
@ -768,6 +839,26 @@ async def test_swarm_handoff() -> None:
assert message == result.messages[index]
index += 1
# Test save and load.
state = await team.save_state()
first_agent2 = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
second_agent2 = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
third_agent2 = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
team2 = Swarm([second_agent2, first_agent2, third_agent2], termination_condition=termination)
await team2.load_state(state)
state2 = await team2.save_state()
assert state == state2
manager_1 = await team._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team._team_id), # pyright: ignore
SwarmGroupChatManager, # pyright: ignore
) # pyright: ignore
manager_2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team2._team_id), # pyright: ignore
SwarmGroupChatManager, # pyright: ignore
) # pyright: ignore
assert manager_1._message_thread == manager_2._message_thread # pyright: ignore
assert manager_1._current_speaker == manager_2._current_speaker # pyright: ignore
@pytest.mark.asyncio
async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:

View File

@ -16,7 +16,8 @@ from autogen_agentchat.messages import (
from autogen_agentchat.teams import (
MagenticOneGroupChat,
)
from autogen_core import CancellationToken
from autogen_agentchat.teams._group_chat._magentic_one._magentic_one_orchestrator import MagenticOneOrchestrator
from autogen_core import AgentId, CancellationToken
from autogen_ext.models import ReplayChatCompletionClient
from utils import FileLogHandler
@ -121,6 +122,27 @@ async def test_magentic_one_group_chat_basic() -> None:
assert result.messages[4].content == "print('Hello, world!')"
assert result.stop_reason is not None and result.stop_reason == "Because"
# Test save and load.
state = await team.save_state()
team2 = MagenticOneGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client)
await team2.load_state(state)
state2 = await team2.save_state()
assert state == state2
manager_1 = await team._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team._team_id), # pyright: ignore
MagenticOneOrchestrator, # pyright: ignore
) # pyright: ignore
manager_2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
AgentId("group_chat_manager", team2._team_id), # pyright: ignore
MagenticOneOrchestrator, # pyright: ignore
) # pyright: ignore
assert manager_1._message_thread == manager_2._message_thread # pyright: ignore
assert manager_1._task == manager_2._task # pyright: ignore
assert manager_1._facts == manager_2._facts # pyright: ignore
assert manager_1._plan == manager_2._plan # pyright: ignore
assert manager_1._n_rounds == manager_2._n_rounds # pyright: ignore
assert manager_1._n_stalls == manager_2._n_stalls # pyright: ignore
@pytest.mark.asyncio
async def test_magentic_one_group_chat_with_stalls() -> None:

View File

@ -48,6 +48,12 @@ A dynamic team that uses handoffs to pass tasks between agents.
How to build custom agents.
:::
:::{grid-item-card} {fas}`users;pst-color-primary` State Management
:link: ./state.html
How to manage state in agents and teams.
:::
::::
```{toctree}
@ -61,4 +67,5 @@ selector-group-chat
swarm
termination
custom-agents
state
```

View File

@ -0,0 +1,299 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Managing State \n",
"\n",
"So far, we have discussed how to build components in a multi-agent application - agents, teams, termination conditions. In many cases, it is useful to save the state of these components to disk and load them back later. This is particularly useful in a web application where stateless endpoints respond to requests and need to load the state of the application from persistent storage.\n",
"\n",
"In this notebook, we will discuss how to save and load the state of agents, teams, and termination conditions. \n",
" \n",
"\n",
"## Saving and Loading Agents\n",
"\n",
"We can get the state of an agent by calling {py:meth}`~autogen_agentchat.agents.AssistantAgent.save_state` method on \n",
"an {py:class}`~autogen_agentchat.agents.AssistantAgent`. "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"In Tanganyika's depths so wide and deep, \n",
"Ancient secrets in still waters sleep, \n",
"Ripples tell tales that time longs to keep. \n"
]
}
],
"source": [
"from autogen_agentchat.agents import AssistantAgent\n",
"from autogen_agentchat.conditions import MaxMessageTermination\n",
"from autogen_agentchat.messages import TextMessage\n",
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
"from autogen_agentchat.ui import Console\n",
"from autogen_core import CancellationToken\n",
"from autogen_ext.models import OpenAIChatCompletionClient\n",
"\n",
"assistant_agent = AssistantAgent(\n",
" name=\"assistant_agent\",\n",
" system_message=\"You are a helpful assistant\",\n",
" model_client=OpenAIChatCompletionClient(\n",
" model=\"gpt-4o-2024-08-06\",\n",
" # api_key=\"YOUR_API_KEY\",\n",
" ),\n",
")\n",
"\n",
"# Use asyncio.run(...) when running in a script.\n",
"response = await assistant_agent.on_messages(\n",
" [TextMessage(content=\"Write a 3 line poem on lake tangayika\", source=\"user\")], CancellationToken()\n",
")\n",
"print(response.chat_message.content)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'type': 'AssistantAgentState', 'version': '1.0.0', 'llm_messages': [{'content': 'Write a 3 line poem on lake tangayika', 'source': 'user', 'type': 'UserMessage'}, {'content': \"In Tanganyika's depths so wide and deep, \\nAncient secrets in still waters sleep, \\nRipples tell tales that time longs to keep. \", 'source': 'assistant_agent', 'type': 'AssistantMessage'}]}\n"
]
}
],
"source": [
"agent_state = await assistant_agent.save_state()\n",
"print(agent_state)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The last line of the poem I wrote was: \n",
"\"Ripples tell tales that time longs to keep.\"\n"
]
}
],
"source": [
"new_assistant_agent = AssistantAgent(\n",
" name=\"assistant_agent\",\n",
" system_message=\"You are a helpful assistant\",\n",
" model_client=OpenAIChatCompletionClient(\n",
" model=\"gpt-4o-2024-08-06\",\n",
" ),\n",
")\n",
"await new_assistant_agent.load_state(agent_state)\n",
"\n",
"# Use asyncio.run(...) when running in a script.\n",
"response = await new_assistant_agent.on_messages(\n",
" [TextMessage(content=\"What was the last line of the previous poem you wrote\", source=\"user\")], CancellationToken()\n",
")\n",
"print(response.chat_message.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```{note}\n",
"For {py:class}`~autogen_agentchat.agents.AssistantAgent`, its state consists of the model_context.\n",
"If your write your own custom agent, consider overriding the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.save_state` and {py:meth}`~autogen_agentchat.agents.BaseChatAgent.load_state` methods to customize the behavior. The default implementations save and load an empty state.\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Saving and Loading Teams \n",
"\n",
"We can get the state of a team by calling `save_state` method on the team and load it back by calling `load_state` method on the team. \n",
"\n",
"When we call `save_state` on a team, it saves the state of all the agents in the team.\n",
"\n",
"We will begin by creating a simple {py:class}`~autogen_agentchat.teams.RoundRobinGroupChat` team with a single agent and ask it to write a poem. "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- user ----------\n",
"Write a beautiful poem 3-line about lake tangayika\n",
"---------- assistant_agent ----------\n",
"In Tanganyika's depths, where light gently weaves, \n",
"Silver reflections dance on ancient water's face, \n",
"Whispered stories of time in the rippling leaves. \n",
"[Prompt tokens: 29, Completion tokens: 36]\n",
"---------- Summary ----------\n",
"Number of messages: 2\n",
"Finish reason: Maximum number of messages 2 reached, current message count: 2\n",
"Total prompt tokens: 29\n",
"Total completion tokens: 36\n",
"Duration: 1.16 seconds\n"
]
}
],
"source": [
"# Define a team.\n",
"assistant_agent = AssistantAgent(\n",
" name=\"assistant_agent\",\n",
" system_message=\"You are a helpful assistant\",\n",
" model_client=OpenAIChatCompletionClient(\n",
" model=\"gpt-4o-2024-08-06\",\n",
" ),\n",
")\n",
"agent_team = RoundRobinGroupChat([assistant_agent], termination_condition=MaxMessageTermination(max_messages=2))\n",
"\n",
"# Run the team and stream messages to the console.\n",
"stream = agent_team.run_stream(task=\"Write a beautiful poem 3-line about lake tangayika\")\n",
"\n",
"# Use asyncio.run(...) when running in a script.\n",
"await Console(stream)\n",
"\n",
"# Save the state of the agent team.\n",
"team_state = await agent_team.save_state()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we reset the team (simulating instantiation of the team), and ask the question `What was the last line of the poem you wrote?`, we see that the team is unable to accomplish this as there is no reference to the previous run."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---------- user ----------\n",
"What was the last line of the poem you wrote?\n",
"---------- assistant_agent ----------\n",
"I don't write poems on my own, but I can help create one with you or try to recall a specific poem if you have one in mind. Let me know what you'd like to do!\n",
"[Prompt tokens: 28, Completion tokens: 39]\n",
"---------- Summary ----------\n",
"Number of messages: 2\n",
"Finish reason: Maximum number of messages 2 reached, current message count: 2\n",
"Total prompt tokens: 28\n",
"Total completion tokens: 39\n",
"Duration: 0.95 seconds\n"
]
},
{
"data": {
"text/plain": [
"TaskResult(messages=[TextMessage(source='user', models_usage=None, type='TextMessage', content='What was the last line of the poem you wrote?'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=28, completion_tokens=39), type='TextMessage', content=\"I don't write poems on my own, but I can help create one with you or try to recall a specific poem if you have one in mind. Let me know what you'd like to do!\")], stop_reason='Maximum number of messages 2 reached, current message count: 2')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await agent_team.reset()\n",
"stream = agent_team.run_stream(task=\"What was the last line of the poem you wrote?\")\n",
"await Console(stream)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we load the state of the team and ask the same question. We see that the team is able to accurately return the last line of the poem it wrote.\n",
"\n",
"Note: You can serialize the state of the team to a file and load it back later."
]
},
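{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"# A sketch (not executed): persist the team state to a JSON file and load it back.\n",
"# The file name is illustrative.\n",
"with open(\"team_state.json\", \"w\") as f:\n",
"    json.dump(team_state, f)\n",
"\n",
"with open(\"team_state.json\") as f:\n",
"    team_state = json.load(f)"
]
},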
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'type': 'TeamState', 'version': '1.0.0', 'agent_states': {'group_chat_manager/c80054be-efb2-4bc7-ba0d-900962092c44': {'type': 'RoundRobinManagerState', 'version': '1.0.0', 'message_thread': [{'source': 'user', 'models_usage': None, 'type': 'TextMessage', 'content': 'Write a beautiful poem 3-line about lake tangayika'}, {'source': 'assistant_agent', 'models_usage': {'prompt_tokens': 29, 'completion_tokens': 36}, 'type': 'TextMessage', 'content': \"In Tanganyika's depths, where light gently weaves, \\nSilver reflections dance on ancient water's face, \\nWhispered stories of time in the rippling leaves. \"}], 'current_turn': 0, 'next_speaker_index': 0}, 'collect_output_messages/c80054be-efb2-4bc7-ba0d-900962092c44': {}, 'assistant_agent/c80054be-efb2-4bc7-ba0d-900962092c44': {'type': 'ChatAgentContainerState', 'version': '1.0.0', 'agent_state': {'type': 'AssistantAgentState', 'version': '1.0.0', 'llm_messages': [{'content': 'Write a beautiful poem 3-line about lake tangayika', 'source': 'user', 'type': 'UserMessage'}, {'content': \"In Tanganyika's depths, where light gently weaves, \\nSilver reflections dance on ancient water's face, \\nWhispered stories of time in the rippling leaves. \", 'source': 'assistant_agent', 'type': 'AssistantMessage'}]}, 'message_buffer': []}}, 'team_id': 'c80054be-efb2-4bc7-ba0d-900962092c44'}\n",
"---------- user ----------\n",
"What was the last line of the poem you wrote?\n",
"---------- assistant_agent ----------\n",
"The last line of the poem I wrote was: \n",
"\"Whispered stories of time in the rippling leaves.\"\n",
"[Prompt tokens: 88, Completion tokens: 24]\n",
"---------- Summary ----------\n",
"Number of messages: 2\n",
"Finish reason: Maximum number of messages 2 reached, current message count: 2\n",
"Total prompt tokens: 88\n",
"Total completion tokens: 24\n",
"Duration: 0.79 seconds\n"
]
},
{
"data": {
"text/plain": [
"TaskResult(messages=[TextMessage(source='user', models_usage=None, type='TextMessage', content='What was the last line of the poem you wrote?'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=88, completion_tokens=24), type='TextMessage', content='The last line of the poem I wrote was: \\n\"Whispered stories of time in the rippling leaves.\"')], stop_reason='Maximum number of messages 2 reached, current message count: 2')"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"print(team_state)\n",
"\n",
"# Load team state.\n",
"await agent_team.load_state(team_state)\n",
"stream = agent_team.run_stream(task=\"What was the last line of the poem you wrote?\")\n",
"await Console(stream)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -1,283 +1,283 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# User Approval for Tool Execution using Intervention Handler\n",
"\n",
"This cookbook shows how to intercept the tool execution using\n",
"an intervention hanlder, and prompt the user for permission to execute the tool."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from typing import Any, List\n",
"\n",
"from autogen_core import AgentId, AgentType, FunctionCall, MessageContext, RoutedAgent, message_handler\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core.base.intervention import DefaultInterventionHandler, DropMessage\n",
"from autogen_core.components.models import (\n",
" ChatCompletionClient,\n",
" LLMMessage,\n",
" SystemMessage,\n",
" UserMessage,\n",
")\n",
"from autogen_core.components.tools import PythonCodeExecutionTool, ToolSchema\n",
"from autogen_core.tool_agent import ToolAgent, ToolException, tool_agent_caller_loop\n",
"from autogen_ext.code_executors import DockerCommandLineCodeExecutor\n",
"from autogen_ext.models import OpenAIChatCompletionClient"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's define a simple message type that carries a string content."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class Message:\n",
" content: str"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create a simple tool use agent that is capable of using tools through a\n",
"{py:class}`~autogen_core.components.tool_agent.ToolAgent`."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"class ToolUseAgent(RoutedAgent):\n",
" \"\"\"An agent that uses tools to perform tasks. It executes the tools\n",
" by itself by sending the tool execution task to a ToolAgent.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" description: str,\n",
" system_messages: List[SystemMessage],\n",
" model_client: ChatCompletionClient,\n",
" tool_schema: List[ToolSchema],\n",
" tool_agent_type: AgentType,\n",
" ) -> None:\n",
" super().__init__(description)\n",
" self._model_client = model_client\n",
" self._system_messages = system_messages\n",
" self._tool_schema = tool_schema\n",
" self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key)\n",
"\n",
" @message_handler\n",
" async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n",
" \"\"\"Handle a user message, execute the model and tools, and returns the response.\"\"\"\n",
" session: List[LLMMessage] = [UserMessage(content=message.content, source=\"User\")]\n",
" # Use the tool agent to execute the tools, and get the output messages.\n",
" output_messages = await tool_agent_caller_loop(\n",
" self,\n",
" tool_agent_id=self._tool_agent_id,\n",
" model_client=self._model_client,\n",
" input_messages=session,\n",
" tool_schema=self._tool_schema,\n",
" cancellation_token=ctx.cancellation_token,\n",
" )\n",
" # Extract the final response from the output messages.\n",
" final_response = output_messages[-1].content\n",
" assert isinstance(final_response, str)\n",
" return Message(content=final_response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tool use agent sends tool call requests to the tool agent to execute tools,\n",
"so we can intercept the messages sent by the tool use agent to the tool agent\n",
"to prompt the user for permission to execute the tool.\n",
"\n",
"Let's create an intervention handler that intercepts the messages and prompts\n",
"user for before allowing the tool execution."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"class ToolInterventionHandler(DefaultInterventionHandler):\n",
" async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:\n",
" if isinstance(message, FunctionCall):\n",
" # Request user prompt for tool execution.\n",
" user_input = input(\n",
" f\"Function call: {message.name}\\nArguments: {message.arguments}\\nDo you want to execute the tool? (y/n): \"\n",
" )\n",
" if user_input.strip().lower() != \"y\":\n",
" raise ToolException(content=\"User denied tool execution.\", call_id=message.id)\n",
" return message"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can create a runtime with the intervention handler registered."
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"# Create the runtime with the intervention handler.\n",
"runtime = SingleThreadedAgentRuntime(intervention_handlers=[ToolInterventionHandler()])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example, we will use a tool for Python code execution.\n",
"First, we create a Docker-based command-line code executor\n",
"using {py:class}`~autogen_core.components.code_executor.docker_executorCommandLineCodeExecutor`,\n",
"and then use it to instantiate a built-in Python code execution tool\n",
"{py:class}`~autogen_core.components.tools.PythonCodeExecutionTool`\n",
"that runs code in a Docker container."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"# Create the docker executor for the Python code execution tool.\n",
"docker_executor = DockerCommandLineCodeExecutor()\n",
"\n",
"# Create the Python code execution tool.\n",
"python_tool = PythonCodeExecutionTool(executor=docker_executor)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Register the agents with tools and tool schema."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AgentType(type='tool_enabled_agent')"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Register agents.\n",
"tool_agent_type = await ToolAgent.register(\n",
" runtime,\n",
" \"tool_executor_agent\",\n",
" lambda: ToolAgent(\n",
" description=\"Tool Executor Agent\",\n",
" tools=[python_tool],\n",
" ),\n",
")\n",
"await ToolUseAgent.register(\n",
" runtime,\n",
" \"tool_enabled_agent\",\n",
" lambda: ToolUseAgent(\n",
" description=\"Tool Use Agent\",\n",
" system_messages=[SystemMessage(\"You are a helpful AI Assistant. Use your tools to solve problems.\")],\n",
" model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n",
" tool_schema=[python_tool.schema],\n",
" tool_agent_type=tool_agent_type,\n",
" ),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the agents by starting the runtime and sending a message to the tool use agent.\n",
"The intervention handler will prompt you for permission to execute the tool."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The output of the code is: **Hello, World!**\n"
]
}
],
"source": [
"# Start the runtime and the docker executor.\n",
"await docker_executor.start()\n",
"runtime.start()\n",
"\n",
"# Send a task to the tool user.\n",
"response = await runtime.send_message(\n",
" Message(\"Run the following Python code: print('Hello, World!')\"), AgentId(\"tool_enabled_agent\", \"default\")\n",
")\n",
"print(response.content)\n",
"\n",
"# Stop the runtime and the docker executor.\n",
"await runtime.stop()\n",
"await docker_executor.stop()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# User Approval for Tool Execution using Intervention Handler\n",
"\n",
"This cookbook shows how to intercept the tool execution using\n",
"an intervention hanlder, and prompt the user for permission to execute the tool."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from typing import Any, List\n",
"\n",
"from autogen_core import AgentId, AgentType, FunctionCall, MessageContext, RoutedAgent, message_handler\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core.base.intervention import DefaultInterventionHandler, DropMessage\n",
"from autogen_core.components.models import (\n",
" ChatCompletionClient,\n",
" LLMMessage,\n",
" SystemMessage,\n",
" UserMessage,\n",
")\n",
"from autogen_core.components.tools import PythonCodeExecutionTool, ToolSchema\n",
"from autogen_core.tool_agent import ToolAgent, ToolException, tool_agent_caller_loop\n",
"from autogen_ext.code_executors import DockerCommandLineCodeExecutor\n",
"from autogen_ext.models import OpenAIChatCompletionClient"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's define a simple message type that carries a string content."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"@dataclass\n",
"class Message:\n",
" content: str"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create a simple tool use agent that is capable of using tools through a\n",
"{py:class}`~autogen_core.components.tool_agent.ToolAgent`."
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"class ToolUseAgent(RoutedAgent):\n",
" \"\"\"An agent that uses tools to perform tasks. It executes the tools\n",
" by itself by sending the tool execution task to a ToolAgent.\"\"\"\n",
"\n",
" def __init__(\n",
" self,\n",
" description: str,\n",
" system_messages: List[SystemMessage],\n",
" model_client: ChatCompletionClient,\n",
" tool_schema: List[ToolSchema],\n",
" tool_agent_type: AgentType,\n",
" ) -> None:\n",
" super().__init__(description)\n",
" self._model_client = model_client\n",
" self._system_messages = system_messages\n",
" self._tool_schema = tool_schema\n",
" self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key)\n",
"\n",
" @message_handler\n",
" async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n",
" \"\"\"Handle a user message, execute the model and tools, and returns the response.\"\"\"\n",
" session: List[LLMMessage] = [UserMessage(content=message.content, source=\"User\")]\n",
" # Use the tool agent to execute the tools, and get the output messages.\n",
" output_messages = await tool_agent_caller_loop(\n",
" self,\n",
" tool_agent_id=self._tool_agent_id,\n",
" model_client=self._model_client,\n",
" input_messages=session,\n",
" tool_schema=self._tool_schema,\n",
" cancellation_token=ctx.cancellation_token,\n",
" )\n",
" # Extract the final response from the output messages.\n",
" final_response = output_messages[-1].content\n",
" assert isinstance(final_response, str)\n",
" return Message(content=final_response)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The tool use agent sends tool call requests to the tool agent to execute tools,\n",
"so we can intercept the messages sent by the tool use agent to the tool agent\n",
"to prompt the user for permission to execute the tool.\n",
"\n",
"Let's create an intervention handler that intercepts the messages and prompts\n",
"user for before allowing the tool execution."
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"class ToolInterventionHandler(DefaultInterventionHandler):\n",
" async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:\n",
" if isinstance(message, FunctionCall):\n",
" # Request user prompt for tool execution.\n",
" user_input = input(\n",
" f\"Function call: {message.name}\\nArguments: {message.arguments}\\nDo you want to execute the tool? (y/n): \"\n",
" )\n",
" if user_input.strip().lower() != \"y\":\n",
" raise ToolException(content=\"User denied tool execution.\", call_id=message.id)\n",
" return message"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, we can create a runtime with the intervention handler registered."
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"# Create the runtime with the intervention handler.\n",
"runtime = SingleThreadedAgentRuntime(intervention_handlers=[ToolInterventionHandler()])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example, we will use a tool for Python code execution.\n",
"First, we create a Docker-based command-line code executor\n",
"using {py:class}`~autogen_core.components.code_executor.docker_executorCommandLineCodeExecutor`,\n",
"and then use it to instantiate a built-in Python code execution tool\n",
"{py:class}`~autogen_core.components.tools.PythonCodeExecutionTool`\n",
"that runs code in a Docker container."
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"# Create the docker executor for the Python code execution tool.\n",
"docker_executor = DockerCommandLineCodeExecutor()\n",
"\n",
"# Create the Python code execution tool.\n",
"python_tool = PythonCodeExecutionTool(executor=docker_executor)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Register the agents with tools and tool schema."
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AgentType(type='tool_enabled_agent')"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Register agents.\n",
"tool_agent_type = await ToolAgent.register(\n",
" runtime,\n",
" \"tool_executor_agent\",\n",
" lambda: ToolAgent(\n",
" description=\"Tool Executor Agent\",\n",
" tools=[python_tool],\n",
" ),\n",
")\n",
"await ToolUseAgent.register(\n",
" runtime,\n",
" \"tool_enabled_agent\",\n",
" lambda: ToolUseAgent(\n",
" description=\"Tool Use Agent\",\n",
" system_messages=[SystemMessage(content=\"You are a helpful AI Assistant. Use your tools to solve problems.\")],\n",
" model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n",
" tool_schema=[python_tool.schema],\n",
" tool_agent_type=tool_agent_type,\n",
" ),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run the agents by starting the runtime and sending a message to the tool use agent.\n",
"The intervention handler will prompt you for permission to execute the tool."
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The output of the code is: **Hello, World!**\n"
]
}
],
"source": [
"# Start the runtime and the docker executor.\n",
"await docker_executor.start()\n",
"runtime.start()\n",
"\n",
"# Send a task to the tool user.\n",
"response = await runtime.send_message(\n",
" Message(\"Run the following Python code: print('Hello, World!')\"), AgentId(\"tool_enabled_agent\", \"default\")\n",
")\n",
"print(response.content)\n",
"\n",
"# Stop the runtime and the docker executor.\n",
"await runtime.stop()\n",
"await docker_executor.stop()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
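
A closely related variant: instead of raising ToolException, an intervention handler can return the DropMessage type to tell the runtime to discard the message altogether, as the on_send signature above suggests. A minimal sketch, assuming the same message types as above; the intervention module path is a best guess for this version of autogen_core and may differ:

from typing import Any

from autogen_core import AgentId, FunctionCall
from autogen_core.base.intervention import DefaultInterventionHandler, DropMessage


class DroppingInterventionHandler(DefaultInterventionHandler):
    async def on_send(self, message: Any, *, sender: AgentId | None, recipient: AgentId) -> Any | type[DropMessage]:
        if isinstance(message, FunctionCall):
            user_input = input(f"Execute {message.name}? (y/n): ")
            if user_input.strip().lower() != "y":
                # Returning the DropMessage type itself (not an instance) drops this message.
                return DropMessage
        return message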

View File

@ -158,7 +158,7 @@
" super().__init__(description=description)\n",
" self._group_chat_topic_type = group_chat_topic_type\n",
" self._model_client = model_client\n",
" self._system_message = SystemMessage(system_message)\n",
" self._system_message = SystemMessage(content=system_message)\n",
" self._chat_history: List[LLMMessage] = []\n",
"\n",
" @message_handler\n",
@ -427,7 +427,7 @@
"Read the above conversation. Then select the next role from {participants} to play. Only return the role.\n",
"\"\"\"\n",
" system_message = SystemMessage(\n",
" selector_prompt.format(\n",
" content=selector_prompt.format(\n",
" roles=roles,\n",
" history=history,\n",
" participants=str(\n",

View File

@ -1,315 +1,315 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tools\n",
"\n",
"Tools are code that can be executed by an agent to perform actions. A tool\n",
"can be a simple function such as a calculator, or an API call to a third-party service\n",
"such as stock price lookup or weather forecast.\n",
"In the context of AI agents, tools are designed to be executed by agents in\n",
"response to model-generated function calls.\n",
"\n",
"AutoGen provides the {py:mod}`autogen_core.components.tools` module with a suite of built-in\n",
"tools and utilities for creating and running custom tools."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Built-in Tools\n",
"\n",
"One of the built-in tools is the {py:class}`~autogen_core.components.tools.PythonCodeExecutionTool`,\n",
"which allows agents to execute Python code snippets.\n",
"\n",
"Here is how you create the tool and use it."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello, world!\n",
"\n"
]
}
],
"source": [
"from autogen_core import CancellationToken\n",
"from autogen_core.components.tools import PythonCodeExecutionTool\n",
"from autogen_ext.code_executors import DockerCommandLineCodeExecutor\n",
"\n",
"# Create the tool.\n",
"code_executor = DockerCommandLineCodeExecutor()\n",
"await code_executor.start()\n",
"code_execution_tool = PythonCodeExecutionTool(code_executor)\n",
"cancellation_token = CancellationToken()\n",
"\n",
"# Use the tool directly without an agent.\n",
"code = \"print('Hello, world!')\"\n",
"result = await code_execution_tool.run_json({\"code\": code}, cancellation_token)\n",
"print(code_execution_tool.return_value_as_string(result))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The {py:class}`~autogen_core.components.code_executor.docker_executorCommandLineCodeExecutor`\n",
"class is a built-in code executor that runs Python code snippets in a subprocess\n",
"in the local command line environment.\n",
"The {py:class}`~autogen_core.components.tools.PythonCodeExecutionTool` class wraps the code executor\n",
"and provides a simple interface to execute Python code snippets.\n",
"\n",
"Other built-in tools will be added in the future."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Custom Function Tools\n",
"\n",
"A tool can also be a simple Python function that performs a specific action.\n",
"To create a custom function tool, you just need to create a Python function\n",
"and use the {py:class}`~autogen_core.components.tools.FunctionTool` class to wrap it.\n",
"\n",
"The {py:class}`~autogen_core.components.tools.FunctionTool` class uses descriptions and type annotations\n",
"to inform the LLM when and how to use a given function. The description provides context\n",
"about the functions purpose and intended use cases, while type annotations inform the LLM about\n",
"the expected parameters and return type.\n",
"\n",
"For example, a simple tool to obtain the stock price of a company might look like this:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"80.44429939059668\n"
]
}
],
"source": [
"import random\n",
"\n",
"from autogen_core import CancellationToken\n",
"from autogen_core.components.tools import FunctionTool\n",
"from typing_extensions import Annotated\n",
"\n",
"\n",
"async def get_stock_price(ticker: str, date: Annotated[str, \"Date in YYYY/MM/DD\"]) -> float:\n",
" # Returns a random stock price for demonstration purposes.\n",
" return random.uniform(10, 200)\n",
"\n",
"\n",
"# Create a function tool.\n",
"stock_price_tool = FunctionTool(get_stock_price, description=\"Get the stock price.\")\n",
"\n",
"# Run the tool.\n",
"cancellation_token = CancellationToken()\n",
"result = await stock_price_tool.run_json({\"ticker\": \"AAPL\", \"date\": \"2021/01/01\"}, cancellation_token)\n",
"\n",
"# Print the result.\n",
"print(stock_price_tool.return_value_as_string(result))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool-Equipped Agent\n",
"\n",
"To use tools with an agent, you can use {py:class}`~autogen_core.components.tool_agent.ToolAgent`,\n",
"by using it in a composition pattern.\n",
"Here is an example tool-use agent that uses {py:class}`~autogen_core.components.tool_agent.ToolAgent`\n",
"as an inner agent for executing tools."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from typing import List\n",
"\n",
"from autogen_core import AgentId, AgentInstantiationContext, MessageContext, RoutedAgent, message_handler\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core.components.models import (\n",
" ChatCompletionClient,\n",
" LLMMessage,\n",
" SystemMessage,\n",
" UserMessage,\n",
")\n",
"from autogen_core.components.tools import FunctionTool, Tool, ToolSchema\n",
"from autogen_core.tool_agent import ToolAgent, tool_agent_caller_loop\n",
"from autogen_ext.models import OpenAIChatCompletionClient\n",
"\n",
"\n",
"@dataclass\n",
"class Message:\n",
" content: str\n",
"\n",
"\n",
"class ToolUseAgent(RoutedAgent):\n",
" def __init__(self, model_client: ChatCompletionClient, tool_schema: List[ToolSchema], tool_agent_type: str) -> None:\n",
" super().__init__(\"An agent with tools\")\n",
" self._system_messages: List[LLMMessage] = [SystemMessage(\"You are a helpful AI assistant.\")]\n",
" self._model_client = model_client\n",
" self._tool_schema = tool_schema\n",
" self._tool_agent_id = AgentId(tool_agent_type, self.id.key)\n",
"\n",
" @message_handler\n",
" async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n",
" # Create a session of messages.\n",
" session: List[LLMMessage] = [UserMessage(content=message.content, source=\"user\")]\n",
" # Run the caller loop to handle tool calls.\n",
" messages = await tool_agent_caller_loop(\n",
" self,\n",
" tool_agent_id=self._tool_agent_id,\n",
" model_client=self._model_client,\n",
" input_messages=session,\n",
" tool_schema=self._tool_schema,\n",
" cancellation_token=ctx.cancellation_token,\n",
" )\n",
" # Return the final response.\n",
" assert isinstance(messages[-1].content, str)\n",
" return Message(content=messages[-1].content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `ToolUseAgent` class uses a convenience function {py:meth}`~autogen_core.components.tool_agent.tool_agent_caller_loop`, \n",
"to handle the interaction between the model and the tool agent.\n",
"The core idea can be described using a simple control flow graph:\n",
"\n",
"![ToolUseAgent control flow graph](tool-use-agent-cfg.svg)\n",
"\n",
"The `ToolUseAgent`'s `handle_user_message` handler handles messages from the user,\n",
"and determines whether the model has generated a tool call.\n",
"If the model has generated tool calls, then the handler sends a function call\n",
"message to the {py:class}`~autogen_core.components.tool_agent.ToolAgent` agent\n",
"to execute the tools,\n",
"and then queries the model again with the results of the tool calls.\n",
"This process continues until the model stops generating tool calls,\n",
"at which point the final response is returned to the user.\n",
"\n",
"By having the tool execution logic in a separate agent,\n",
"we expose the model-tool interactions to the agent runtime as messages, so the tool executions\n",
"can be observed externally and intercepted if necessary.\n",
"\n",
"To run the agent, we need to create a runtime and register the agent."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AgentType(type='tool_use_agent')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create a runtime.\n",
"runtime = SingleThreadedAgentRuntime()\n",
"# Create the tools.\n",
"tools: List[Tool] = [FunctionTool(get_stock_price, description=\"Get the stock price.\")]\n",
"# Register the agents.\n",
"await ToolAgent.register(runtime, \"tool_executor_agent\", lambda: ToolAgent(\"tool executor agent\", tools))\n",
"await ToolUseAgent.register(\n",
" runtime,\n",
" \"tool_use_agent\",\n",
" lambda: ToolUseAgent(\n",
" OpenAIChatCompletionClient(model=\"gpt-4o-mini\"), [tool.schema for tool in tools], \"tool_executor_agent\"\n",
" ),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This example uses the {py:class}`autogen_core.components.models.OpenAIChatCompletionClient`,\n",
"for Azure OpenAI and other clients, see [Model Clients](./model-clients.ipynb).\n",
"Let's test the agent with a question about stock price."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The stock price of NVDA (NVIDIA Corporation) on June 1, 2024, was approximately $179.46.\n"
]
}
],
"source": [
"# Start processing messages.\n",
"runtime.start()\n",
"# Send a direct message to the tool agent.\n",
"tool_use_agent = AgentId(\"tool_use_agent\", \"default\")\n",
"response = await runtime.send_message(Message(\"What is the stock price of NVDA on 2024/06/01?\"), tool_use_agent)\n",
"print(response.content)\n",
"# Stop processing messages.\n",
"await runtime.stop()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "autogen_core",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tools\n",
"\n",
"Tools are code that can be executed by an agent to perform actions. A tool\n",
"can be a simple function such as a calculator, or an API call to a third-party service\n",
"such as stock price lookup or weather forecast.\n",
"In the context of AI agents, tools are designed to be executed by agents in\n",
"response to model-generated function calls.\n",
"\n",
"AutoGen provides the {py:mod}`autogen_core.components.tools` module with a suite of built-in\n",
"tools and utilities for creating and running custom tools."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Built-in Tools\n",
"\n",
"One of the built-in tools is the {py:class}`~autogen_core.components.tools.PythonCodeExecutionTool`,\n",
"which allows agents to execute Python code snippets.\n",
"\n",
"Here is how you create the tool and use it."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hello, world!\n",
"\n"
]
}
],
"source": [
"from autogen_core import CancellationToken\n",
"from autogen_core.components.tools import PythonCodeExecutionTool\n",
"from autogen_ext.code_executors import DockerCommandLineCodeExecutor\n",
"\n",
"# Create the tool.\n",
"code_executor = DockerCommandLineCodeExecutor()\n",
"await code_executor.start()\n",
"code_execution_tool = PythonCodeExecutionTool(code_executor)\n",
"cancellation_token = CancellationToken()\n",
"\n",
"# Use the tool directly without an agent.\n",
"code = \"print('Hello, world!')\"\n",
"result = await code_execution_tool.run_json({\"code\": code}, cancellation_token)\n",
"print(code_execution_tool.return_value_as_string(result))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The {py:class}`~autogen_core.components.code_executor.docker_executorCommandLineCodeExecutor`\n",
"class is a built-in code executor that runs Python code snippets in a subprocess\n",
"in the local command line environment.\n",
"The {py:class}`~autogen_core.components.tools.PythonCodeExecutionTool` class wraps the code executor\n",
"and provides a simple interface to execute Python code snippets.\n",
"\n",
"Other built-in tools will be added in the future."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Custom Function Tools\n",
"\n",
"A tool can also be a simple Python function that performs a specific action.\n",
"To create a custom function tool, you just need to create a Python function\n",
"and use the {py:class}`~autogen_core.components.tools.FunctionTool` class to wrap it.\n",
"\n",
"The {py:class}`~autogen_core.components.tools.FunctionTool` class uses descriptions and type annotations\n",
"to inform the LLM when and how to use a given function. The description provides context\n",
"about the functions purpose and intended use cases, while type annotations inform the LLM about\n",
"the expected parameters and return type.\n",
"\n",
"For example, a simple tool to obtain the stock price of a company might look like this:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"80.44429939059668\n"
]
}
],
"source": [
"import random\n",
"\n",
"from autogen_core import CancellationToken\n",
"from autogen_core.components.tools import FunctionTool\n",
"from typing_extensions import Annotated\n",
"\n",
"\n",
"async def get_stock_price(ticker: str, date: Annotated[str, \"Date in YYYY/MM/DD\"]) -> float:\n",
" # Returns a random stock price for demonstration purposes.\n",
" return random.uniform(10, 200)\n",
"\n",
"\n",
"# Create a function tool.\n",
"stock_price_tool = FunctionTool(get_stock_price, description=\"Get the stock price.\")\n",
"\n",
"# Run the tool.\n",
"cancellation_token = CancellationToken()\n",
"result = await stock_price_tool.run_json({\"ticker\": \"AAPL\", \"date\": \"2021/01/01\"}, cancellation_token)\n",
"\n",
"# Print the result.\n",
"print(stock_price_tool.return_value_as_string(result))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool-Equipped Agent\n",
"\n",
"To use tools with an agent, you can use {py:class}`~autogen_core.components.tool_agent.ToolAgent`,\n",
"by using it in a composition pattern.\n",
"Here is an example tool-use agent that uses {py:class}`~autogen_core.components.tool_agent.ToolAgent`\n",
"as an inner agent for executing tools."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from dataclasses import dataclass\n",
"from typing import List\n",
"\n",
"from autogen_core import AgentId, AgentInstantiationContext, MessageContext, RoutedAgent, message_handler\n",
"from autogen_core.application import SingleThreadedAgentRuntime\n",
"from autogen_core.components.models import (\n",
" ChatCompletionClient,\n",
" LLMMessage,\n",
" SystemMessage,\n",
" UserMessage,\n",
")\n",
"from autogen_core.components.tools import FunctionTool, Tool, ToolSchema\n",
"from autogen_core.tool_agent import ToolAgent, tool_agent_caller_loop\n",
"from autogen_ext.models import OpenAIChatCompletionClient\n",
"\n",
"\n",
"@dataclass\n",
"class Message:\n",
" content: str\n",
"\n",
"\n",
"class ToolUseAgent(RoutedAgent):\n",
" def __init__(self, model_client: ChatCompletionClient, tool_schema: List[ToolSchema], tool_agent_type: str) -> None:\n",
" super().__init__(\"An agent with tools\")\n",
" self._system_messages: List[LLMMessage] = [SystemMessage(content=\"You are a helpful AI assistant.\")]\n",
" self._model_client = model_client\n",
" self._tool_schema = tool_schema\n",
" self._tool_agent_id = AgentId(tool_agent_type, self.id.key)\n",
"\n",
" @message_handler\n",
" async def handle_user_message(self, message: Message, ctx: MessageContext) -> Message:\n",
" # Create a session of messages.\n",
" session: List[LLMMessage] = [UserMessage(content=message.content, source=\"user\")]\n",
" # Run the caller loop to handle tool calls.\n",
" messages = await tool_agent_caller_loop(\n",
" self,\n",
" tool_agent_id=self._tool_agent_id,\n",
" model_client=self._model_client,\n",
" input_messages=session,\n",
" tool_schema=self._tool_schema,\n",
" cancellation_token=ctx.cancellation_token,\n",
" )\n",
" # Return the final response.\n",
" assert isinstance(messages[-1].content, str)\n",
" return Message(content=messages[-1].content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `ToolUseAgent` class uses a convenience function {py:meth}`~autogen_core.components.tool_agent.tool_agent_caller_loop`, \n",
"to handle the interaction between the model and the tool agent.\n",
"The core idea can be described using a simple control flow graph:\n",
"\n",
"![ToolUseAgent control flow graph](tool-use-agent-cfg.svg)\n",
"\n",
"The `ToolUseAgent`'s `handle_user_message` handler handles messages from the user,\n",
"and determines whether the model has generated a tool call.\n",
"If the model has generated tool calls, then the handler sends a function call\n",
"message to the {py:class}`~autogen_core.components.tool_agent.ToolAgent` agent\n",
"to execute the tools,\n",
"and then queries the model again with the results of the tool calls.\n",
"This process continues until the model stops generating tool calls,\n",
"at which point the final response is returned to the user.\n",
"\n",
"By having the tool execution logic in a separate agent,\n",
"we expose the model-tool interactions to the agent runtime as messages, so the tool executions\n",
"can be observed externally and intercepted if necessary.\n",
"\n",
"To run the agent, we need to create a runtime and register the agent."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AgentType(type='tool_use_agent')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Create a runtime.\n",
"runtime = SingleThreadedAgentRuntime()\n",
"# Create the tools.\n",
"tools: List[Tool] = [FunctionTool(get_stock_price, description=\"Get the stock price.\")]\n",
"# Register the agents.\n",
"await ToolAgent.register(runtime, \"tool_executor_agent\", lambda: ToolAgent(\"tool executor agent\", tools))\n",
"await ToolUseAgent.register(\n",
" runtime,\n",
" \"tool_use_agent\",\n",
" lambda: ToolUseAgent(\n",
" OpenAIChatCompletionClient(model=\"gpt-4o-mini\"), [tool.schema for tool in tools], \"tool_executor_agent\"\n",
" ),\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This example uses the {py:class}`autogen_core.components.models.OpenAIChatCompletionClient`,\n",
"for Azure OpenAI and other clients, see [Model Clients](./model-clients.ipynb).\n",
"Let's test the agent with a question about stock price."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The stock price of NVDA (NVIDIA Corporation) on June 1, 2024, was approximately $179.46.\n"
]
}
],
"source": [
"# Start processing messages.\n",
"runtime.start()\n",
"# Send a direct message to the tool agent.\n",
"tool_use_agent = AgentId(\"tool_use_agent\", \"default\")\n",
"response = await runtime.send_message(Message(\"What is the stock price of NVDA on 2024/06/01?\"), tool_use_agent)\n",
"print(response.content)\n",
"# Stop processing messages.\n",
"await runtime.stop()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "autogen_core",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -36,7 +36,7 @@ Read the following conversation. Then select the next role from {participants} t
Read the above conversation. Then select the next role from {participants} to play. Only return the role.
"""
select_speaker_messages = [SystemMessage(select_speaker_prompt)]
select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
response = await client.create(messages=select_speaker_messages)
assert isinstance(response.content, str)
mentions = await mentioned_agents(response.content, agents)

View File

@ -92,7 +92,7 @@ def convert_messages_to_llm_messages(
converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
if converted_message_2 is not None:
result.append(converted_message_2)
case FunctionExecutionResultMessage(_):
case FunctionExecutionResultMessage(content=_):
converted_message_3 = convert_tool_call_response_message(message, handle_unrepresentable)
if converted_message_3 is not None:
result.append(converted_message_3)
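
The pattern change above is forced by the same migration: dataclasses generate __match_args__, so positional class patterns like Case(_) work, while pydantic BaseModel defines no __match_args__, so the pattern must bind fields by keyword. A minimal sketch:

from dataclasses import dataclass

from pydantic import BaseModel


@dataclass
class OldMsg:
    content: str


class NewMsg(BaseModel):
    content: str


match OldMsg(content="x"):
    case OldMsg(_):  # fine: the dataclass sets __match_args__ = ("content",)
        pass

match NewMsg(content="x"):
    case NewMsg(content=_):  # required: BaseModel has no __match_args__
        pass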

View File

@ -31,7 +31,7 @@ class BaseGroupChatAgent(RoutedAgent):
super().__init__(description=description)
self._group_chat_topic_type = group_chat_topic_type
self._model_client = model_client
self._system_message = SystemMessage(system_message)
self._system_message = SystemMessage(content=system_message)
self._chat_history: List[LLMMessage] = []
self._ui_config = ui_config
self.console = Console()
@ -126,7 +126,7 @@ Read the following conversation. Then select the next role from {participants} t
Read the above conversation. Then select the next role from {participants} to play. if you think it's enough talking (for example they have talked for {self._max_rounds} rounds), return 'FINISH'.
"""
system_message = SystemMessage(selector_prompt)
system_message = SystemMessage(content=selector_prompt)
completion = await self._model_client.create([system_message], cancellation_token=ctx.cancellation_token)
assert isinstance(

View File

@ -160,12 +160,14 @@ class SchedulingAssistantAgent(RoutedAgent):
self._name = name
self._model_client = model_client
self._system_messages = [
SystemMessage(f"""
SystemMessage(
content=f"""
I am a helpful AI assistant that helps schedule meetings.
If there are missing parameters, I will ask for them.
Today's date is {datetime.datetime.now().strftime("%Y-%m-%d")}
""")
"""
)
]
@message_handler

View File

@ -116,10 +116,12 @@ class ClosureAgent(BaseAgent, ClosureContext):
return await self._closure(self, message, ctx)
async def save_state(self) -> Mapping[str, Any]:
raise ValueError("save_state not implemented for ClosureAgent")
"""Closure agents do not have state. So this method always returns an empty dictionary."""
return {}
async def load_state(self, state: Mapping[str, Any]) -> None:
raise ValueError("load_state not implemented for ClosureAgent")
"""Closure agents do not have state. So this method does nothing."""
pass
@classmethod
async def register_closure(
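
With ClosureAgent now returning an empty mapping instead of raising, every agent supports the same state protocol. A minimal sketch of the round-trip this commit standardizes (agent construction elided; agent and restored_agent are hypothetical instances of the same stateful agent class):

state = await agent.save_state()        # a plain, JSON-serializable Mapping[str, Any]
await restored_agent.load_state(state)  # the new instance resumes where agent left off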

View File

@ -1,42 +1,49 @@
from dataclasses import dataclass
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from ... import FunctionCall, Image
@dataclass
class SystemMessage:
class SystemMessage(BaseModel):
content: str
type: Literal["SystemMessage"] = "SystemMessage"
@dataclass
class UserMessage:
class UserMessage(BaseModel):
content: Union[str, List[Union[str, Image]]]
# Name of the agent that sent this message
source: str
type: Literal["UserMessage"] = "UserMessage"
@dataclass
class AssistantMessage:
class AssistantMessage(BaseModel):
content: Union[str, List[FunctionCall]]
# Name of the agent that sent this message
source: str
type: Literal["AssistantMessage"] = "AssistantMessage"
@dataclass
class FunctionExecutionResult:
class FunctionExecutionResult(BaseModel):
content: str
call_id: str
@dataclass
class FunctionExecutionResultMessage:
class FunctionExecutionResultMessage(BaseModel):
content: List[FunctionExecutionResult]
type: Literal["FunctionExecutionResultMessage"] = "FunctionExecutionResultMessage"
LLMMessage = Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage]
LLMMessage = Annotated[
Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage], Field(discriminator="type")
]
@dataclass
@ -54,16 +61,14 @@ class TopLogprob:
bytes: Optional[List[int]] = None
@dataclass
class ChatCompletionTokenLogprob:
class ChatCompletionTokenLogprob(BaseModel):
token: str
logprob: float
top_logprobs: Optional[List[TopLogprob] | None] = None
bytes: Optional[List[int]] = None
@dataclass
class CreateResult:
class CreateResult(BaseModel):
finish_reason: FinishReasons
content: Union[str, List[FunctionCall]]
usage: RequestUsage
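
The Literal "type" field added to each message class is what makes LLMMessage a discriminated union: on deserialization, pydantic reads "type" to pick the concrete class. A minimal round-trip sketch using pydantic's TypeAdapter, assuming the classes defined above:

from pydantic import TypeAdapter

history: list[LLMMessage] = [
    SystemMessage(content="You are a helpful AI assistant."),
    UserMessage(content="Hi!", source="user"),
]
adapter = TypeAdapter(list[LLMMessage])
payload = adapter.dump_python(history)       # plain dicts, each carrying a "type" key
restored = adapter.validate_python(payload)  # concrete SystemMessage / UserMessage again
assert restored == history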

View File

@ -36,9 +36,11 @@ class FileSurfer(BaseChatAgent):
DEFAULT_DESCRIPTION = "An agent that can handle local files."
DEFAULT_SYSTEM_MESSAGES = [
SystemMessage("""
SystemMessage(
content="""
You are a helpful AI Assistant.
When given a user query, use available functions to help the user with their request."""),
When given a user query, use available functions to help the user with their request."""
),
]
def __init__(
@ -78,7 +80,7 @@ class FileSurfer(BaseChatAgent):
except BaseException:
content = f"File surfing error:\n\n{traceback.format_exc()}"
self._chat_history.append(AssistantMessage(content, source=self.name))
self._chat_history.append(AssistantMessage(content=content, source=self.name))
return Response(chat_message=TextMessage(content=content, source=self.name))
async def on_reset(self, cancellation_token: CancellationToken) -> None:

View File

@ -190,7 +190,7 @@ class MultimodalWebSurfer(BaseChatAgent):
except BaseException:
content = f"Web surfing error:\n\n{traceback.format_exc()}"
self._chat_history.append(AssistantMessage(content, source=self.name))
self._chat_history.append(AssistantMessage(content=content, source=self.name))
return Response(chat_message=TextMessage(content=content, source=self.name))
async def on_reset(self, cancellation_token: CancellationToken) -> None:
@ -712,7 +712,7 @@ class MultimodalWebSurfer(BaseChatAgent):
for line in page_markdown.splitlines():
message = UserMessage(
# content=[
prompt + buffer + line,
content=prompt + buffer + line,
# ag_image,
# ],
source=self.name,

View File

@ -33,7 +33,7 @@ class LLMAgent(RoutedAgent):
@property
def _fixed_message_history_type(self) -> List[SystemMessage]:
return [SystemMessage(msg.content) for msg in self._chat_history]
return [SystemMessage(content=msg.content) for msg in self._chat_history]
@default_subscription

View File

@ -21,7 +21,8 @@ class Coder(BaseWorker):
DEFAULT_DESCRIPTION = "A helpful and general-purpose AI assistant that has strong language skills, Python skills, and Linux command line skills."
DEFAULT_SYSTEM_MESSAGES = [
SystemMessage("""You are a helpful AI assistant.
SystemMessage(
content="""You are a helpful AI assistant.
Solve tasks using your coding and language skills.
In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.
1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.
@ -31,7 +32,8 @@ When using code, you must indicate the script type in the code block. The user c
If you want the user to save the code in a file before executing it, put # filename: <filename> inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.
If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.
When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
Reply "TERMINATE" in the end when everything is done.""")
Reply "TERMINATE" in the end when everything is done."""
)
]
def __init__(

View File

@ -23,9 +23,11 @@ class FileSurfer(BaseWorker):
DEFAULT_DESCRIPTION = "An agent that can handle local files."
DEFAULT_SYSTEM_MESSAGES = [
SystemMessage("""
SystemMessage(
content="""
You are a helpful AI Assistant.
When given a user query, use available functions to help the user with their request."""),
When given a user query, use available functions to help the user with their request."""
),
]
def __init__(

View File

@ -817,7 +817,7 @@ When deciding between tools, consider if the request can be best addressed by:
for line in re.split(r"([\r\n]+)", page_markdown):
message = UserMessage(
# content=[
prompt + buffer + line,
content=prompt + buffer + line,
# ag_image,
# ],
source=self.metadata["type"],

View File

@ -47,7 +47,7 @@ class LedgerOrchestrator(BaseOrchestrator):
It uses a ledger (implemented as a JSON generated by the LLM) to keep track of task progress and select the next agent that should speak."""
DEFAULT_SYSTEM_MESSAGES = [
SystemMessage(ORCHESTRATOR_SYSTEM_MESSAGE),
SystemMessage(content=ORCHESTRATOR_SYSTEM_MESSAGE),
]
def __init__(