Mirror of https://github.com/microsoft/autogen.git, synced 2025-11-03 11:20:35 +00:00.
Load and Save state in AgentChat (#4436)

1. Convert dataclass types to Pydantic BaseModel.
2. Add save_state and load_state for ChatAgent.
3. Add state types for AgentChat.

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

This commit is contained in:
parent fef06fdc8a
commit 777f2abbd7
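For orientation before the diff: end to end, the new API looks roughly like the sketch below. This is a minimal sketch based on the code and notebook added in this commit; the agent name and task text are illustrative, and the OpenAI client assumes an API key is available in the environment.

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_ext.models import OpenAIChatCompletionClient


async def main() -> None:
    agent = AssistantAgent(
        "helper",  # illustrative name
        model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06"),
    )
    await agent.run(task="Remember that my favorite color is teal.")

    # save_state() returns a JSON-serializable mapping (a Pydantic model dump).
    state = await agent.save_state()

    # A fresh agent with the same configuration resumes from the saved model context.
    restored = AssistantAgent(
        "helper",
        model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06"),
    )
    await restored.load_state(state)


asyncio.run(main())
```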
@@ -127,7 +127,7 @@ async def main() -> None:
         lambda: Coder(
             model_client=client,
             system_messages=[
-                SystemMessage("""You are a general-purpose AI assistant and can handle many questions -- but you don't have access to a web browser. However, the user you are talking to does have a browser, and you can see the screen. Provide short direct instructions to them to take you where you need to go to answer the initial question posed to you.
+                SystemMessage(content="""You are a general-purpose AI assistant and can handle many questions -- but you don't have access to a web browser. However, the user you are talking to does have a browser, and you can see the screen. Provide short direct instructions to them to take you where you need to go to answer the initial question posed to you.

 Once the user has taken the final necessary action to complete the task, and you have fully addressed the initial request, reply with the word TERMINATE.""",
             )
@@ -2,7 +2,7 @@ import asyncio
 import json
 import logging
 import warnings
-from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Sequence
+from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Mapping, Sequence

 from autogen_core import CancellationToken, FunctionCall
 from autogen_core.components.models import (
@@ -29,6 +29,7 @@ from ..messages import (
     ToolCallMessage,
     ToolCallResultMessage,
 )
+from ..state import AssistantAgentState
 from ._base_chat_agent import BaseChatAgent

 event_logger = logging.getLogger(EVENT_LOGGER_NAME)
@@ -49,6 +50,12 @@ class Handoff(HandoffBase):
 class AssistantAgent(BaseChatAgent):
     """An agent that provides assistance with tool use.

+    ```{note}
+    The assistant agent is not thread-safe or coroutine-safe.
+    It should not be shared between multiple tasks or coroutines, and it should
+    not call its methods concurrently.
+    ```
+
     Args:
         name (str): The name of the agent.
         model_client (ChatCompletionClient): The model client to use for inference.
@@ -224,6 +231,7 @@ class AssistantAgent(BaseChatAgent):
                 f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
             )
         self._model_context: List[LLMMessage] = []
         self._is_running = False

     @property
     def produced_message_types(self) -> List[type[ChatMessage]]:
@@ -327,3 +335,13 @@ class AssistantAgent(BaseChatAgent):
     async def on_reset(self, cancellation_token: CancellationToken) -> None:
         """Reset the assistant agent to its initialization state."""
         self._model_context.clear()
+
+    async def save_state(self) -> Mapping[str, Any]:
+        """Save the current state of the assistant agent."""
+        return AssistantAgentState(llm_messages=self._model_context.copy()).model_dump()
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        """Load the state of the assistant agent."""
+        assistant_agent_state = AssistantAgentState.model_validate(state)
+        self._model_context.clear()
+        self._model_context.extend(assistant_agent_state.llm_messages)
@@ -1,10 +1,11 @@
 from abc import ABC, abstractmethod
-from typing import AsyncGenerator, List, Sequence
+from typing import Any, AsyncGenerator, List, Mapping, Sequence

 from autogen_core import CancellationToken

 from ..base import ChatAgent, Response, TaskResult
 from ..messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
+from ..state import BaseState


 class BaseChatAgent(ChatAgent, ABC):
@@ -117,3 +118,11 @@ class BaseChatAgent(ChatAgent, ABC):
     async def on_reset(self, cancellation_token: CancellationToken) -> None:
         """Resets the agent to its initialization state."""
         ...
+
+    async def save_state(self) -> Mapping[str, Any]:
+        """Export state. Default implementation for stateless agents."""
+        return BaseState().model_dump()
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        """Restore agent from saved state. Default implementation for stateless agents."""
+        BaseState.model_validate(state)
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import AsyncGenerator, List, Protocol, Sequence, runtime_checkable
+from typing import Any, AsyncGenerator, List, Mapping, Protocol, Sequence, runtime_checkable

 from autogen_core import CancellationToken

@@ -54,3 +54,11 @@ class ChatAgent(TaskRunner, Protocol):
     async def on_reset(self, cancellation_token: CancellationToken) -> None:
         """Resets the agent to its initialization state."""
         ...
+
+    async def save_state(self) -> Mapping[str, Any]:
+        """Save agent state for later restoration."""
+        ...
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        """Restore agent from saved state."""
+        ...
@@ -1,4 +1,4 @@
-from typing import Protocol
+from typing import Any, Mapping, Protocol

 from ._task import TaskRunner

@@ -7,3 +7,11 @@ class Team(TaskRunner, Protocol):
     async def reset(self) -> None:
         """Reset the team and all its participants to its initial state."""
         ...
+
+    async def save_state(self) -> Mapping[str, Any]:
+        """Save the current state of the team."""
+        ...
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        """Load the state of the team."""
+        ...
@@ -1,8 +1,9 @@
-from typing import List
+from typing import List, Literal

 from autogen_core import FunctionCall, Image
 from autogen_core.components.models import FunctionExecutionResult, RequestUsage
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field
+from typing_extensions import Annotated


 class BaseMessage(BaseModel):
@@ -23,6 +24,8 @@ class TextMessage(BaseMessage):
     content: str
     """The content of the message."""

+    type: Literal["TextMessage"] = "TextMessage"
+

 class MultiModalMessage(BaseMessage):
     """A multimodal message."""
@@ -30,6 +33,8 @@ class MultiModalMessage(BaseMessage):
     content: List[str | Image]
     """The content of the message."""

+    type: Literal["MultiModalMessage"] = "MultiModalMessage"
+

 class StopMessage(BaseMessage):
     """A message requesting stop of a conversation."""
@@ -37,6 +42,8 @@ class StopMessage(BaseMessage):
     content: str
     """The content for the stop message."""

+    type: Literal["StopMessage"] = "StopMessage"
+

 class HandoffMessage(BaseMessage):
     """A message requesting handoff of a conversation to another agent."""
@@ -47,6 +54,8 @@ class HandoffMessage(BaseMessage):
     content: str
     """The handoff message to the target agent."""

+    type: Literal["HandoffMessage"] = "HandoffMessage"
+

 class ToolCallMessage(BaseMessage):
     """A message signaling the use of tools."""
@@ -54,6 +63,8 @@ class ToolCallMessage(BaseMessage):
     content: List[FunctionCall]
     """The tool calls."""

+    type: Literal["ToolCallMessage"] = "ToolCallMessage"
+

 class ToolCallResultMessage(BaseMessage):
     """A message signaling the results of tool calls."""
@@ -61,12 +72,17 @@ class ToolCallResultMessage(BaseMessage):
     content: List[FunctionExecutionResult]
     """The tool call results."""

+    type: Literal["ToolCallResultMessage"] = "ToolCallResultMessage"

-ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage
+
+ChatMessage = Annotated[TextMessage | MultiModalMessage | StopMessage | HandoffMessage, Field(discriminator="type")]
 """Messages for agent-to-agent communication."""


-AgentMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallMessage | ToolCallResultMessage
+AgentMessage = Annotated[
+    TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallMessage | ToolCallResultMessage,
+    Field(discriminator="type"),
+]
 """All message types."""
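The `type: Literal[...]` fields and `Field(discriminator="type")` annotations above are what make saved message threads round-trippable: a dumped message records its concrete class, and validation picks that class back out of the union. A small sketch of the idea, assuming Pydantic v2's `TypeAdapter` (this usage is not part of the diff itself):

```python
from pydantic import TypeAdapter

from autogen_agentchat.messages import AgentMessage, TextMessage

adapter = TypeAdapter(AgentMessage)

# Dumping records the concrete class in the "type" field.
data = TextMessage(content="hello", source="user").model_dump()
assert data["type"] == "TextMessage"

# The discriminator lets validation reconstruct the right class from the union.
message = adapter.validate_python(data)
assert isinstance(message, TextMessage)
```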
@@ -0,0 +1,25 @@
"""State management for agents, teams and termination conditions."""

from ._states import (
    AssistantAgentState,
    BaseGroupChatManagerState,
    BaseState,
    ChatAgentContainerState,
    MagenticOneOrchestratorState,
    RoundRobinManagerState,
    SelectorManagerState,
    SwarmManagerState,
    TeamState,
)

__all__ = [
    "BaseState",
    "AssistantAgentState",
    "BaseGroupChatManagerState",
    "ChatAgentContainerState",
    "RoundRobinManagerState",
    "SelectorManagerState",
    "SwarmManagerState",
    "MagenticOneOrchestratorState",
    "TeamState",
]
@@ -0,0 +1,81 @@
from typing import Any, List, Mapping, Optional

from autogen_core.components.models import (
    LLMMessage,
)
from pydantic import BaseModel, Field

from ..messages import (
    AgentMessage,
    ChatMessage,
)


class BaseState(BaseModel):
    """Base class for all saveable state."""

    type: str = Field(default="BaseState")
    version: str = Field(default="1.0.0")


class AssistantAgentState(BaseState):
    """State for an assistant agent."""

    llm_messages: List[LLMMessage] = Field(default_factory=list)
    type: str = Field(default="AssistantAgentState")


class TeamState(BaseState):
    """State for a team of agents."""

    agent_states: Mapping[str, Any] = Field(default_factory=dict)
    team_id: str = Field(default="")
    type: str = Field(default="TeamState")


class BaseGroupChatManagerState(BaseState):
    """Base state for all group chat managers."""

    message_thread: List[AgentMessage] = Field(default_factory=list)
    current_turn: int = Field(default=0)
    type: str = Field(default="BaseGroupChatManagerState")


class ChatAgentContainerState(BaseState):
    """State for a container of chat agents."""

    agent_state: Mapping[str, Any] = Field(default_factory=dict)
    message_buffer: List[ChatMessage] = Field(default_factory=list)
    type: str = Field(default="ChatAgentContainerState")


class RoundRobinManagerState(BaseGroupChatManagerState):
    """State for :class:`~autogen_agentchat.teams.RoundRobinGroupChat` manager."""

    next_speaker_index: int = Field(default=0)
    type: str = Field(default="RoundRobinManagerState")


class SelectorManagerState(BaseGroupChatManagerState):
    """State for :class:`~autogen_agentchat.teams.SelectorGroupChat` manager."""

    previous_speaker: Optional[str] = Field(default=None)
    type: str = Field(default="SelectorManagerState")


class SwarmManagerState(BaseGroupChatManagerState):
    """State for :class:`~autogen_agentchat.teams.Swarm` manager."""

    current_speaker: str = Field(default="")
    type: str = Field(default="SwarmManagerState")


class MagenticOneOrchestratorState(BaseGroupChatManagerState):
    """State for :class:`~autogen_agentchat.teams.MagenticOneGroupChat` orchestrator."""

    task: str = Field(default="")
    facts: str = Field(default="")
    plan: str = Field(default="")
    n_rounds: int = Field(default=0)
    n_stalls: int = Field(default=0)
    type: str = Field(default="MagenticOneOrchestratorState")
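Custom agents can plug into the same scheme by subclassing `BaseState`. A hedged sketch (the `ScratchpadAgentState` name and its field are hypothetical, not part of this commit):

```python
from typing import List

from pydantic import Field

from autogen_agentchat.state import BaseState


class ScratchpadAgentState(BaseState):
    """Hypothetical state for a custom agent that keeps free-form notes."""

    notes: List[str] = Field(default_factory=list)
    type: str = Field(default="ScratchpadAgentState")


# A custom agent's overrides would mirror AssistantAgent's:
# save_state() returns ScratchpadAgentState(notes=self._notes).model_dump(),
# and load_state() restores via ScratchpadAgentState.model_validate(state).
```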
@@ -2,7 +2,7 @@ import asyncio
 import logging
 import uuid
 from abc import ABC, abstractmethod
-from typing import AsyncGenerator, Callable, List
+from typing import Any, AsyncGenerator, Callable, List, Mapping

 from autogen_core import (
     AgentId,
@@ -20,6 +20,7 @@ from autogen_core.application import SingleThreadedAgentRuntime
 from ... import EVENT_LOGGER_NAME
 from ...base import ChatAgent, TaskResult, Team, TerminationCondition
 from ...messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
+from ...state import TeamState
 from ._chat_agent_container import ChatAgentContainer
 from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
 from ._sequential_routed_agent import SequentialRoutedAgent
@@ -493,3 +494,38 @@ class BaseGroupChat(Team, ABC):

             # Indicate that the team is no longer running.
             self._is_running = False
+
+    async def save_state(self) -> Mapping[str, Any]:
+        """Save the state of the group chat team."""
+        if not self._initialized:
+            raise RuntimeError("The group chat has not been initialized. It must be run before it can be saved.")
+
+        if self._is_running:
+            raise RuntimeError("The team cannot be saved while it is running.")
+        self._is_running = True
+
+        try:
+            # Save the state of the runtime. This will save the state of the participants and the group chat manager.
+            agent_states = await self._runtime.save_state()
+            return TeamState(agent_states=agent_states, team_id=self._team_id).model_dump()
+        finally:
+            # Indicate that the team is no longer running.
+            self._is_running = False
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        """Load the state of the group chat team."""
+        if not self._initialized:
+            await self._init(self._runtime)
+
+        if self._is_running:
+            raise RuntimeError("The team cannot be loaded while it is running.")
+        self._is_running = True
+
+        try:
+            # Load the state of the runtime. This will load the state of the participants and the group chat manager.
+            team_state = TeamState.model_validate(state)
+            self._team_id = team_state.team_id
+            await self._runtime.load_state(team_state.agent_states)
+        finally:
+            # Indicate that the team is no longer running.
+            self._is_running = False
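Because `save_state` returns a plain mapping, team state can be persisted with the stdlib `json` module, as the new notebook later in this diff notes. A hedged sketch (the file name and task text are illustrative; call it via `asyncio.run(...)` from a script):

```python
import json

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_ext.models import OpenAIChatCompletionClient


async def persist_team_state() -> None:
    agent = AssistantAgent(
        "assistant_agent",
        model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-08-06"),
    )
    team = RoundRobinGroupChat([agent], termination_condition=MaxMessageTermination(max_messages=2))
    await team.run(task="Write a one-line haiku about state.")

    # The dump is JSON-serializable, so it survives a round trip to disk.
    with open("team_state.json", "w") as f:
        json.dump(await team.save_state(), f)

    # Later, even in a new process: rebuild an identical team, then restore it.
    with open("team_state.json") as f:
        await team.load_state(json.load(f))
```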
@@ -1,9 +1,10 @@
-from typing import Any, List
+from typing import Any, List, Mapping

 from autogen_core import DefaultTopicId, MessageContext, event, rpc

 from ...base import ChatAgent, Response
 from ...messages import ChatMessage
+from ...state import ChatAgentContainerState
 from ._events import GroupChatAgentResponse, GroupChatMessage, GroupChatRequestPublish, GroupChatReset, GroupChatStart
 from ._sequential_routed_agent import SequentialRoutedAgent

@@ -75,3 +76,13 @@ class ChatAgentContainer(SequentialRoutedAgent):

     async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
         raise ValueError(f"Unhandled message in agent container: {type(message)}")
+
+    async def save_state(self) -> Mapping[str, Any]:
+        agent_state = await self._agent.save_state()
+        state = ChatAgentContainerState(agent_state=agent_state, message_buffer=list(self._message_buffer))
+        return state.model_dump()
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        container_state = ChatAgentContainerState.model_validate(state)
+        self._message_buffer = list(container_state.message_buffer)
+        await self._agent.load_state(container_state.agent_state)
@@ -1,6 +1,6 @@
 import json
 import logging
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Mapping

 from autogen_core import AgentId, CancellationToken, DefaultTopicId, Image, MessageContext, event, rpc
 from autogen_core.components.models import (
@@ -13,6 +13,7 @@ from autogen_core.components.models import (
 from .... import TRACE_LOGGER_NAME
 from ....base import Response, TerminationCondition
 from ....messages import AgentMessage, ChatMessage, MultiModalMessage, StopMessage, TextMessage
+from ....state import MagenticOneOrchestratorState
 from .._base_group_chat_manager import BaseGroupChatManager
 from .._events import (
     GroupChatAgentResponse,
@@ -178,6 +179,28 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
     async def validate_group_state(self, message: ChatMessage | None) -> None:
         pass

+    async def save_state(self) -> Mapping[str, Any]:
+        state = MagenticOneOrchestratorState(
+            message_thread=list(self._message_thread),
+            current_turn=self._current_turn,
+            task=self._task,
+            facts=self._facts,
+            plan=self._plan,
+            n_rounds=self._n_rounds,
+            n_stalls=self._n_stalls,
+        )
+        return state.model_dump()
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        orchestrator_state = MagenticOneOrchestratorState.model_validate(state)
+        self._message_thread = orchestrator_state.message_thread
+        self._current_turn = orchestrator_state.current_turn
+        self._task = orchestrator_state.task
+        self._facts = orchestrator_state.facts
+        self._plan = orchestrator_state.plan
+        self._n_rounds = orchestrator_state.n_rounds
+        self._n_stalls = orchestrator_state.n_stalls
+
     async def select_speaker(self, thread: List[AgentMessage]) -> str:
         """Not used in this orchestrator, we select next speaker in _orchestrate_step."""
         return ""
@@ -1,7 +1,8 @@
-from typing import Callable, List
+from typing import Any, Callable, List, Mapping

 from ...base import ChatAgent, TerminationCondition
 from ...messages import AgentMessage, ChatMessage
+from ...state import RoundRobinManagerState
 from ._base_group_chat import BaseGroupChat
 from ._base_group_chat_manager import BaseGroupChatManager

@@ -38,6 +39,20 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
             await self._termination_condition.reset()
         self._next_speaker_index = 0

+    async def save_state(self) -> Mapping[str, Any]:
+        state = RoundRobinManagerState(
+            message_thread=list(self._message_thread),
+            current_turn=self._current_turn,
+            next_speaker_index=self._next_speaker_index,
+        )
+        return state.model_dump()
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        round_robin_state = RoundRobinManagerState.model_validate(state)
+        self._message_thread = list(round_robin_state.message_thread)
+        self._current_turn = round_robin_state.current_turn
+        self._next_speaker_index = round_robin_state.next_speaker_index
+
     async def select_speaker(self, thread: List[AgentMessage]) -> str:
         """Select a speaker from the participants in a round-robin fashion."""
         current_speaker_index = self._next_speaker_index
@@ -1,10 +1,10 @@
 import logging
 import re
-from typing import Callable, Dict, List, Sequence
+from typing import Any, Callable, Dict, List, Mapping, Sequence

 from autogen_core.components.models import ChatCompletionClient, SystemMessage

-from ... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
+from ... import TRACE_LOGGER_NAME
 from ...base import ChatAgent, TerminationCondition
 from ...messages import (
     AgentMessage,
@@ -16,11 +16,11 @@ from ...messages import (
     ToolCallMessage,
     ToolCallResultMessage,
 )
+from ...state import SelectorManagerState
 from ._base_group_chat import BaseGroupChat
 from ._base_group_chat_manager import BaseGroupChatManager

 trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
-event_logger = logging.getLogger(EVENT_LOGGER_NAME)


 class SelectorGroupChatManager(BaseGroupChatManager):
@@ -64,6 +64,20 @@ class SelectorGroupChatManager(BaseGroupChatManager):
             await self._termination_condition.reset()
         self._previous_speaker = None

+    async def save_state(self) -> Mapping[str, Any]:
+        state = SelectorManagerState(
+            message_thread=list(self._message_thread),
+            current_turn=self._current_turn,
+            previous_speaker=self._previous_speaker,
+        )
+        return state.model_dump()
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        selector_state = SelectorManagerState.model_validate(state)
+        self._message_thread = list(selector_state.message_thread)
+        self._current_turn = selector_state.current_turn
+        self._previous_speaker = selector_state.previous_speaker
+
     async def select_speaker(self, thread: List[AgentMessage]) -> str:
         """Selects the next speaker in a group chat using a ChatCompletion client,
         with the selector function as override if it returns a speaker name.
@@ -121,7 +135,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
         select_speaker_prompt = self._selector_prompt.format(
             roles=roles, participants=str(participants), history=history
         )
-        select_speaker_messages = [SystemMessage(select_speaker_prompt)]
+        select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
         response = await self._model_client.create(messages=select_speaker_messages)
         assert isinstance(response.content, str)
         mentions = self._mentioned_agents(response.content, self._participant_topic_types)
@@ -1,14 +1,11 @@
 import logging
-from typing import Callable, List
+from typing import Any, Callable, List, Mapping

-from ... import EVENT_LOGGER_NAME
 from ...base import ChatAgent, TerminationCondition
 from ...messages import AgentMessage, ChatMessage, HandoffMessage
+from ...state import SwarmManagerState
 from ._base_group_chat import BaseGroupChat
 from ._base_group_chat_manager import BaseGroupChatManager

-event_logger = logging.getLogger(EVENT_LOGGER_NAME)
-

 class SwarmGroupChatManager(BaseGroupChatManager):
     """A group chat manager that selects the next speaker based on handoff message only."""
@@ -77,6 +74,20 @@ class SwarmGroupChatManager(BaseGroupChatManager):
                 return self._current_speaker
         return self._current_speaker

+    async def save_state(self) -> Mapping[str, Any]:
+        state = SwarmManagerState(
+            message_thread=list(self._message_thread),
+            current_turn=self._current_turn,
+            current_speaker=self._current_speaker,
+        )
+        return state.model_dump()
+
+    async def load_state(self, state: Mapping[str, Any]) -> None:
+        swarm_state = SwarmManagerState.model_validate(state)
+        self._message_thread = list(swarm_state.message_thread)
+        self._current_turn = swarm_state.current_turn
+        self._current_speaker = swarm_state.current_speaker
+

 class Swarm(BaseGroupChat):
     """A group chat team that selects the next speaker based on handoff message only.
@@ -112,12 +112,12 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     ]
     mock = _MockChatCompletion(chat_completions)
     monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
-    tool_use_agent = AssistantAgent(
+    agent = AssistantAgent(
         "tool_use_agent",
         model_client=OpenAIChatCompletionClient(model=model, api_key=""),
         tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
     )
-    result = await tool_use_agent.run(task="task")
+    result = await agent.run(task="task")
     assert len(result.messages) == 4
     assert isinstance(result.messages[0], TextMessage)
     assert result.messages[0].models_usage is None
@@ -135,13 +135,24 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     # Test streaming.
     mock._curr_index = 0  # pyright: ignore
     index = 0
-    async for message in tool_use_agent.run_stream(task="task"):
+    async for message in agent.run_stream(task="task"):
         if isinstance(message, TaskResult):
             assert message == result
         else:
             assert message == result.messages[index]
             index += 1

+    # Test state saving and loading.
+    state = await agent.save_state()
+    agent2 = AssistantAgent(
+        "tool_use_agent",
+        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
+        tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
+    )
+    await agent2.load_state(state)
+    state2 = await agent2.save_state()
+    assert state == state2


 @pytest.mark.asyncio
 async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -28,12 +28,15 @@ from autogen_agentchat.teams import (
     SelectorGroupChat,
     Swarm,
 )
+from autogen_agentchat.teams._group_chat._round_robin_group_chat import RoundRobinGroupChatManager
+from autogen_agentchat.teams._group_chat._selector_group_chat import SelectorGroupChatManager
+from autogen_agentchat.teams._group_chat._swarm_group_chat import SwarmGroupChatManager
 from autogen_agentchat.ui import Console
-from autogen_core import CancellationToken, FunctionCall
+from autogen_core import AgentId, CancellationToken, FunctionCall
 from autogen_core.components.code_executor import LocalCommandLineCodeExecutor
 from autogen_core.components.models import FunctionExecutionResult
 from autogen_core.components.tools import FunctionTool
-from autogen_ext.models import OpenAIChatCompletionClient
+from autogen_ext.models import OpenAIChatCompletionClient, ReplayChatCompletionClient
 from openai.resources.chat.completions import AsyncCompletions
 from openai.types.chat.chat_completion import ChatCompletion, Choice
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
@@ -217,6 +220,38 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
     assert result.messages[1:] == result_2.messages[1:]


+@pytest.mark.asyncio
+async def test_round_robin_group_chat_state() -> None:
+    model_client = ReplayChatCompletionClient(
+        ["No facts", "No plan", "print('Hello, world!')", "TERMINATE"],
+    )
+    agent1 = AssistantAgent("agent1", model_client=model_client)
+    agent2 = AssistantAgent("agent2", model_client=model_client)
+    termination = TextMentionTermination("TERMINATE")
+    team1 = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination)
+    await team1.run(task="Write a program that prints 'Hello, world!'")
+    state = await team1.save_state()
+
+    agent3 = AssistantAgent("agent1", model_client=model_client)
+    agent4 = AssistantAgent("agent2", model_client=model_client)
+    team2 = RoundRobinGroupChat(participants=[agent3, agent4], termination_condition=termination)
+    await team2.load_state(state)
+    state2 = await team2.save_state()
+    assert state == state2
+    assert agent3._model_context == agent1._model_context  # pyright: ignore
+    assert agent4._model_context == agent2._model_context  # pyright: ignore
+    manager_1 = await team1._runtime.try_get_underlying_agent_instance(  # pyright: ignore
+        AgentId("group_chat_manager", team1._team_id),  # pyright: ignore
+        RoundRobinGroupChatManager,  # pyright: ignore
+    )  # pyright: ignore
+    manager_2 = await team2._runtime.try_get_underlying_agent_instance(  # pyright: ignore
+        AgentId("group_chat_manager", team2._team_id),  # pyright: ignore
+        RoundRobinGroupChatManager,  # pyright: ignore
+    )  # pyright: ignore
+    assert manager_1._current_turn == manager_2._current_turn  # pyright: ignore
+    assert manager_1._message_thread == manager_2._message_thread  # pyright: ignore
+
+
 @pytest.mark.asyncio
 async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
     model = "gpt-4o-2024-05-13"
@@ -528,6 +563,42 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
     assert result2 == result


+@pytest.mark.asyncio
+async def test_selector_group_chat_state() -> None:
+    model_client = ReplayChatCompletionClient(
+        ["agent1", "No facts", "agent2", "No plan", "agent1", "print('Hello, world!')", "agent2", "TERMINATE"],
+    )
+    agent1 = AssistantAgent("agent1", model_client=model_client)
+    agent2 = AssistantAgent("agent2", model_client=model_client)
+    termination = TextMentionTermination("TERMINATE")
+    team1 = SelectorGroupChat(
+        participants=[agent1, agent2], termination_condition=termination, model_client=model_client
+    )
+    await team1.run(task="Write a program that prints 'Hello, world!'")
+    state = await team1.save_state()
+
+    agent3 = AssistantAgent("agent1", model_client=model_client)
+    agent4 = AssistantAgent("agent2", model_client=model_client)
+    team2 = SelectorGroupChat(
+        participants=[agent3, agent4], termination_condition=termination, model_client=model_client
+    )
+    await team2.load_state(state)
+    state2 = await team2.save_state()
+    assert state == state2
+    assert agent3._model_context == agent1._model_context  # pyright: ignore
+    assert agent4._model_context == agent2._model_context  # pyright: ignore
+    manager_1 = await team1._runtime.try_get_underlying_agent_instance(  # pyright: ignore
+        AgentId("group_chat_manager", team1._team_id),  # pyright: ignore
+        SelectorGroupChatManager,  # pyright: ignore
+    )  # pyright: ignore
+    manager_2 = await team2._runtime.try_get_underlying_agent_instance(  # pyright: ignore
+        AgentId("group_chat_manager", team2._team_id),  # pyright: ignore
+        SelectorGroupChatManager,  # pyright: ignore
+    )  # pyright: ignore
+    assert manager_1._message_thread == manager_2._message_thread  # pyright: ignore
+    assert manager_1._previous_speaker == manager_2._previous_speaker  # pyright: ignore
+
+
 @pytest.mark.asyncio
 async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None:
     model = "gpt-4o-2024-05-13"
@@ -768,6 +839,26 @@ async def test_swarm_handoff() -> None:
             assert message == result.messages[index]
             index += 1

+    # Test save and load.
+    state = await team.save_state()
+    first_agent2 = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
+    second_agent2 = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
+    third_agent2 = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
+    team2 = Swarm([second_agent2, first_agent2, third_agent2], termination_condition=termination)
+    await team2.load_state(state)
+    state2 = await team2.save_state()
+    assert state == state2
+    manager_1 = await team._runtime.try_get_underlying_agent_instance(  # pyright: ignore
+        AgentId("group_chat_manager", team._team_id),  # pyright: ignore
+        SwarmGroupChatManager,  # pyright: ignore
+    )  # pyright: ignore
+    manager_2 = await team2._runtime.try_get_underlying_agent_instance(  # pyright: ignore
+        AgentId("group_chat_manager", team2._team_id),  # pyright: ignore
+        SwarmGroupChatManager,  # pyright: ignore
+    )  # pyright: ignore
+    assert manager_1._message_thread == manager_2._message_thread  # pyright: ignore
+    assert manager_1._current_speaker == manager_2._current_speaker  # pyright: ignore
+
+
 @pytest.mark.asyncio
 async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
@@ -16,7 +16,8 @@ from autogen_agentchat.messages import (
 from autogen_agentchat.teams import (
     MagenticOneGroupChat,
 )
-from autogen_core import CancellationToken
+from autogen_agentchat.teams._group_chat._magentic_one._magentic_one_orchestrator import MagenticOneOrchestrator
+from autogen_core import AgentId, CancellationToken
 from autogen_ext.models import ReplayChatCompletionClient
 from utils import FileLogHandler
@@ -121,6 +122,27 @@ async def test_magentic_one_group_chat_basic() -> None:
     assert result.messages[4].content == "print('Hello, world!')"
     assert result.stop_reason is not None and result.stop_reason == "Because"

+    # Test save and load.
+    state = await team.save_state()
+    team2 = MagenticOneGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client)
+    await team2.load_state(state)
+    state2 = await team2.save_state()
+    assert state == state2
+    manager_1 = await team._runtime.try_get_underlying_agent_instance(  # pyright: ignore
+        AgentId("group_chat_manager", team._team_id),  # pyright: ignore
+        MagenticOneOrchestrator,  # pyright: ignore
+    )  # pyright: ignore
+    manager_2 = await team2._runtime.try_get_underlying_agent_instance(  # pyright: ignore
+        AgentId("group_chat_manager", team2._team_id),  # pyright: ignore
+        MagenticOneOrchestrator,  # pyright: ignore
+    )  # pyright: ignore
+    assert manager_1._message_thread == manager_2._message_thread  # pyright: ignore
+    assert manager_1._task == manager_2._task  # pyright: ignore
+    assert manager_1._facts == manager_2._facts  # pyright: ignore
+    assert manager_1._plan == manager_2._plan  # pyright: ignore
+    assert manager_1._n_rounds == manager_2._n_rounds  # pyright: ignore
+    assert manager_1._n_stalls == manager_2._n_stalls  # pyright: ignore
+
+
 @pytest.mark.asyncio
 async def test_magentic_one_group_chat_with_stalls() -> None:
@@ -48,6 +48,12 @@ A dynamic team that uses handoffs to pass tasks between agents.
 How to build custom agents.
 :::

+:::{grid-item-card} {fas}`users;pst-color-primary` State Management
+:link: ./state.html
+
+How to manage state in agents and teams.
+:::
+
 ::::

 ```{toctree}
@@ -61,4 +67,5 @@ selector-group-chat
 swarm
 termination
 custom-agents
+state
 ```
@@ -0,0 +1,299 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Managing State\n",
    "\n",
    "So far, we have discussed how to build components in a multi-agent application - agents, teams, termination conditions. In many cases, it is useful to save the state of these components to disk and load them back later. This is particularly useful in a web application where stateless endpoints respond to requests and need to load the state of the application from persistent storage.\n",
    "\n",
    "In this notebook, we will discuss how to save and load the state of agents, teams, and termination conditions.\n",
    "\n",
    "## Saving and Loading Agents\n",
    "\n",
    "We can get the state of an agent by calling the {py:meth}`~autogen_agentchat.agents.AssistantAgent.save_state` method on\n",
    "an {py:class}`~autogen_agentchat.agents.AssistantAgent`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "In Tanganyika's depths so wide and deep, \n",
      "Ancient secrets in still waters sleep, \n",
      "Ripples tell tales that time longs to keep. \n"
     ]
    }
   ],
   "source": [
    "from autogen_agentchat.agents import AssistantAgent\n",
    "from autogen_agentchat.conditions import MaxMessageTermination\n",
    "from autogen_agentchat.messages import TextMessage\n",
    "from autogen_agentchat.teams import RoundRobinGroupChat\n",
    "from autogen_agentchat.ui import Console\n",
    "from autogen_core import CancellationToken\n",
    "from autogen_ext.models import OpenAIChatCompletionClient\n",
    "\n",
    "assistant_agent = AssistantAgent(\n",
    "    name=\"assistant_agent\",\n",
    "    system_message=\"You are a helpful assistant\",\n",
    "    model_client=OpenAIChatCompletionClient(\n",
    "        model=\"gpt-4o-2024-08-06\",\n",
    "        # api_key=\"YOUR_API_KEY\",\n",
    "    ),\n",
    ")\n",
    "\n",
    "# Use asyncio.run(...) when running in a script.\n",
    "response = await assistant_agent.on_messages(\n",
    "    [TextMessage(content=\"Write a 3 line poem on lake tangayika\", source=\"user\")], CancellationToken()\n",
    ")\n",
    "print(response.chat_message.content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'type': 'AssistantAgentState', 'version': '1.0.0', 'llm_messages': [{'content': 'Write a 3 line poem on lake tangayika', 'source': 'user', 'type': 'UserMessage'}, {'content': \"In Tanganyika's depths so wide and deep, \\nAncient secrets in still waters sleep, \\nRipples tell tales that time longs to keep. \", 'source': 'assistant_agent', 'type': 'AssistantMessage'}]}\n"
     ]
    }
   ],
   "source": [
    "agent_state = await assistant_agent.save_state()\n",
    "print(agent_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The last line of the poem I wrote was: \n",
      "\"Ripples tell tales that time longs to keep.\"\n"
     ]
    }
   ],
   "source": [
    "new_assistant_agent = AssistantAgent(\n",
    "    name=\"assistant_agent\",\n",
    "    system_message=\"You are a helpful assistant\",\n",
    "    model_client=OpenAIChatCompletionClient(\n",
    "        model=\"gpt-4o-2024-08-06\",\n",
    "    ),\n",
    ")\n",
    "await new_assistant_agent.load_state(agent_state)\n",
    "\n",
    "# Use asyncio.run(...) when running in a script.\n",
    "response = await new_assistant_agent.on_messages(\n",
    "    [TextMessage(content=\"What was the last line of the previous poem you wrote\", source=\"user\")], CancellationToken()\n",
    ")\n",
    "print(response.chat_message.content)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{note}\n",
    "For {py:class}`~autogen_agentchat.agents.AssistantAgent`, its state consists of the model_context.\n",
    "If you write your own custom agent, consider overriding the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.save_state` and {py:meth}`~autogen_agentchat.agents.BaseChatAgent.load_state` methods to customize the behavior. The default implementations save and load an empty state.\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving and Loading Teams\n",
    "\n",
    "We can get the state of a team by calling the `save_state` method on the team and load it back by calling the `load_state` method on the team.\n",
    "\n",
    "When we call `save_state` on a team, it saves the state of all the agents in the team.\n",
    "\n",
    "We will begin by creating a simple {py:class}`~autogen_agentchat.teams.RoundRobinGroupChat` team with a single agent and ask it to write a poem."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---------- user ----------\n",
      "Write a beautiful poem 3-line about lake tangayika\n",
      "---------- assistant_agent ----------\n",
      "In Tanganyika's depths, where light gently weaves, \n",
      "Silver reflections dance on ancient water's face, \n",
      "Whispered stories of time in the rippling leaves. \n",
      "[Prompt tokens: 29, Completion tokens: 36]\n",
      "---------- Summary ----------\n",
      "Number of messages: 2\n",
      "Finish reason: Maximum number of messages 2 reached, current message count: 2\n",
      "Total prompt tokens: 29\n",
      "Total completion tokens: 36\n",
      "Duration: 1.16 seconds\n"
     ]
    }
   ],
   "source": [
    "# Define a team.\n",
    "assistant_agent = AssistantAgent(\n",
    "    name=\"assistant_agent\",\n",
    "    system_message=\"You are a helpful assistant\",\n",
    "    model_client=OpenAIChatCompletionClient(\n",
    "        model=\"gpt-4o-2024-08-06\",\n",
    "    ),\n",
    ")\n",
    "agent_team = RoundRobinGroupChat([assistant_agent], termination_condition=MaxMessageTermination(max_messages=2))\n",
    "\n",
    "# Run the team and stream messages to the console.\n",
    "stream = agent_team.run_stream(task=\"Write a beautiful poem 3-line about lake tangayika\")\n",
    "\n",
    "# Use asyncio.run(...) when running in a script.\n",
    "await Console(stream)\n",
    "\n",
    "# Save the state of the agent team.\n",
    "team_state = await agent_team.save_state()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we reset the team (simulating instantiation of the team) and ask the question `What was the last line of the poem you wrote?`, we see that the team is unable to accomplish this, as there is no reference to the previous run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "---------- user ----------\n",
      "What was the last line of the poem you wrote?\n",
      "---------- assistant_agent ----------\n",
      "I don't write poems on my own, but I can help create one with you or try to recall a specific poem if you have one in mind. Let me know what you'd like to do!\n",
      "[Prompt tokens: 28, Completion tokens: 39]\n",
      "---------- Summary ----------\n",
      "Number of messages: 2\n",
      "Finish reason: Maximum number of messages 2 reached, current message count: 2\n",
      "Total prompt tokens: 28\n",
      "Total completion tokens: 39\n",
      "Duration: 0.95 seconds\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TaskResult(messages=[TextMessage(source='user', models_usage=None, type='TextMessage', content='What was the last line of the poem you wrote?'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=28, completion_tokens=39), type='TextMessage', content=\"I don't write poems on my own, but I can help create one with you or try to recall a specific poem if you have one in mind. Let me know what you'd like to do!\")], stop_reason='Maximum number of messages 2 reached, current message count: 2')"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await agent_team.reset()\n",
    "stream = agent_team.run_stream(task=\"What was the last line of the poem you wrote?\")\n",
    "await Console(stream)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we load the state of the team and ask the same question. We see that the team is able to accurately return the last line of the poem it wrote.\n",
    "\n",
    "Note: You can serialize the state of the team to a file and load it back later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'type': 'TeamState', 'version': '1.0.0', 'agent_states': {'group_chat_manager/c80054be-efb2-4bc7-ba0d-900962092c44': {'type': 'RoundRobinManagerState', 'version': '1.0.0', 'message_thread': [{'source': 'user', 'models_usage': None, 'type': 'TextMessage', 'content': 'Write a beautiful poem 3-line about lake tangayika'}, {'source': 'assistant_agent', 'models_usage': {'prompt_tokens': 29, 'completion_tokens': 36}, 'type': 'TextMessage', 'content': \"In Tanganyika's depths, where light gently weaves, \\nSilver reflections dance on ancient water's face, \\nWhispered stories of time in the rippling leaves. \"}], 'current_turn': 0, 'next_speaker_index': 0}, 'collect_output_messages/c80054be-efb2-4bc7-ba0d-900962092c44': {}, 'assistant_agent/c80054be-efb2-4bc7-ba0d-900962092c44': {'type': 'ChatAgentContainerState', 'version': '1.0.0', 'agent_state': {'type': 'AssistantAgentState', 'version': '1.0.0', 'llm_messages': [{'content': 'Write a beautiful poem 3-line about lake tangayika', 'source': 'user', 'type': 'UserMessage'}, {'content': \"In Tanganyika's depths, where light gently weaves, \\nSilver reflections dance on ancient water's face, \\nWhispered stories of time in the rippling leaves. \", 'source': 'assistant_agent', 'type': 'AssistantMessage'}]}, 'message_buffer': []}}, 'team_id': 'c80054be-efb2-4bc7-ba0d-900962092c44'}\n",
      "---------- user ----------\n",
      "What was the last line of the poem you wrote?\n",
      "---------- assistant_agent ----------\n",
      "The last line of the poem I wrote was: \n",
      "\"Whispered stories of time in the rippling leaves.\"\n",
      "[Prompt tokens: 88, Completion tokens: 24]\n",
      "---------- Summary ----------\n",
      "Number of messages: 2\n",
      "Finish reason: Maximum number of messages 2 reached, current message count: 2\n",
      "Total prompt tokens: 88\n",
      "Total completion tokens: 24\n",
      "Duration: 0.79 seconds\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TaskResult(messages=[TextMessage(source='user', models_usage=None, type='TextMessage', content='What was the last line of the poem you wrote?'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=88, completion_tokens=24), type='TextMessage', content='The last line of the poem I wrote was: \\n\"Whispered stories of time in the rippling leaves.\"')], stop_reason='Maximum number of messages 2 reached, current message count: 2')"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(team_state)\n",
    "\n",
    "# Load team state.\n",
    "await agent_team.load_state(team_state)\n",
    "stream = agent_team.run_stream(task=\"What was the last line of the poem you wrote?\")\n",
    "await Console(stream)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
@@ -213,7 +213,7 @@
 "    \"tool_enabled_agent\",\n",
 "    lambda: ToolUseAgent(\n",
 "        description=\"Tool Use Agent\",\n",
-"        system_messages=[SystemMessage(\"You are a helpful AI Assistant. Use your tools to solve problems.\")],\n",
+"        system_messages=[SystemMessage(content=\"You are a helpful AI Assistant. Use your tools to solve problems.\")],\n",
 "        model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n",
 "        tool_schema=[python_tool.schema],\n",
 "        tool_agent_type=tool_agent_type,\n",
@@ -174,7 +174,7 @@
 "        factory=lambda: TaxSpecialist(\n",
 "            description=\"A tax specialist 1\",\n",
 "            specialty=TaxSpecialty.PLANNING,\n",
-"            system_messages=[SystemMessage(\"You are a tax specialist.\")],\n",
+"            system_messages=[SystemMessage(content=\"You are a tax specialist.\")],\n",
 "        ),\n",
 "    )\n",
 "\n",
@@ -184,7 +184,7 @@
 "        factory=lambda: TaxSpecialist(\n",
 "            description=\"A tax specialist 2\",\n",
 "            specialty=TaxSpecialty.DISPUTE_RESOLUTION,\n",
-"            system_messages=[SystemMessage(\"You are a tax specialist.\")],\n",
+"            system_messages=[SystemMessage(content=\"You are a tax specialist.\")],\n",
 "        ),\n",
 "    )\n",
 "\n",
@@ -303,7 +303,7 @@
 "        factory=lambda specialty=specialty: TaxSpecialist(  # type: ignore\n",
 "            description=f\"A tax specialist in {specialty.value}.\",\n",
 "            specialty=specialty,\n",
-"            system_messages=[SystemMessage(f\"You are a tax specialist in {specialty.value}.\")],\n",
+"            system_messages=[SystemMessage(content=f\"You are a tax specialist in {specialty.value}.\")],\n",
 "        ),\n",
 "    )\n",
 "    specialist_subscription = DefaultSubscription(agent_type=specialist_agent_type)\n",
@@ -414,7 +414,7 @@
 "        factory=lambda specialty=specialty: TaxSpecialist(  # type: ignore\n",
 "            description=f\"A tax specialist in {specialty.value}.\",\n",
 "            specialty=specialty,\n",
-"            system_messages=[SystemMessage(f\"You are a tax specialist in {specialty.value}.\")],\n",
+"            system_messages=[SystemMessage(content=f\"You are a tax specialist in {specialty.value}.\")],\n",
 "        ),\n",
 "    )\n",
 "    specialist_subscription = TypeSubscription(topic_type=specialty.value, agent_type=specialist_agent_type)\n",
@@ -545,7 +545,7 @@
 "        factory=lambda specialty=specialty: TaxSpecialist(  # type: ignore\n",
 "            description=f\"A tax specialist in {specialty.value}.\",\n",
 "            specialty=specialty,\n",
-"            system_messages=[SystemMessage(f\"You are a tax specialist in {specialty.value}.\")],\n",
+"            system_messages=[SystemMessage(content=f\"You are a tax specialist in {specialty.value}.\")],\n",
 "        ),\n",
 "    )\n",
 "    for tenant in tenants:\n",
@@ -158,7 +158,7 @@
 "        super().__init__(description=description)\n",
 "        self._group_chat_topic_type = group_chat_topic_type\n",
 "        self._model_client = model_client\n",
-"        self._system_message = SystemMessage(system_message)\n",
+"        self._system_message = SystemMessage(content=system_message)\n",
 "        self._chat_history: List[LLMMessage] = []\n",
 "\n",
 "    @message_handler\n",
@@ -427,7 +427,7 @@
 "Read the above conversation. Then select the next role from {participants} to play. Only return the role.\n",
 "\"\"\"\n",
 "        system_message = SystemMessage(\n",
-"            selector_prompt.format(\n",
+"            content=selector_prompt.format(\n",
 "                roles=roles,\n",
 "                history=history,\n",
 "                participants=str(\n",
@@ -112,7 +112,7 @@
 "            system_prompt = \"You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.\\n\\nResponses from models:\"\n",
 "            system_prompt += \"\\n\" + \"\\n\\n\".join([f\"{i+1}. {r}\" for i, r in enumerate(message.previous_results)])\n",
 "            model_result = await self._model_client.create(\n",
-"                [SystemMessage(system_prompt), UserMessage(content=message.task, source=\"user\")]\n",
+"                [SystemMessage(content=system_prompt), UserMessage(content=message.task, source=\"user\")]\n",
 "            )\n",
 "        else:\n",
 "            # If no previous results are provided, we can simply pass the user query to the model.\n",
@@ -174,7 +174,7 @@
 "            system_prompt = \"You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.\\n\\nResponses from models:\"\n",
 "            system_prompt += \"\\n\" + \"\\n\\n\".join([f\"{i+1}. {r}\" for i, r in enumerate(worker_task.previous_results)])\n",
 "            model_result = await self._model_client.create(\n",
-"                [SystemMessage(system_prompt), UserMessage(content=message.task, source=\"user\")]\n",
+"                [SystemMessage(content=system_prompt), UserMessage(content=message.task, source=\"user\")]\n",
 "            )\n",
 "            assert isinstance(model_result.content, str)\n",
 "            return FinalResult(result=model_result.content)"
@@ -146,7 +146,7 @@
 "        self._buffer: Dict[int, List[IntermediateSolverResponse]] = {}\n",
 "        self._system_messages = [\n",
 "            SystemMessage(\n",
-"                (\n",
+"                content=(\n",
 "                    \"You are a helpful assistant with expertise in mathematics and reasoning. \"\n",
 "                    \"Your task is to assist in solving a math reasoning problem by providing \"\n",
 "                    \"a clear and detailed solution. Limit your output within 100 words, \"\n",
@@ -343,7 +343,7 @@
 "class SimpleAgent(RoutedAgent):\n",
 "    def __init__(self, model_client: ChatCompletionClient) -> None:\n",
 "        super().__init__(\"A simple agent\")\n",
-"        self._system_messages = [SystemMessage(\"You are a helpful AI assistant.\")]\n",
+"        self._system_messages = [SystemMessage(content=\"You are a helpful AI assistant.\")]\n",
 "        self._model_client = model_client\n",
 "\n",
 "    @message_handler\n",
@@ -478,7 +478,7 @@
 "class SimpleAgentWithContext(RoutedAgent):\n",
 "    def __init__(self, model_client: ChatCompletionClient) -> None:\n",
 "        super().__init__(\"A simple agent\")\n",
-"        self._system_messages = [SystemMessage(\"You are a helpful AI assistant.\")]\n",
+"        self._system_messages = [SystemMessage(content=\"You are a helpful AI assistant.\")]\n",
 "        self._model_client = model_client\n",
 "        self._model_context = BufferedChatCompletionContext(buffer_size=5)\n",
 "\n",
@@ -176,7 +176,7 @@
 "class ToolUseAgent(RoutedAgent):\n",
 "    def __init__(self, model_client: ChatCompletionClient, tool_schema: List[ToolSchema], tool_agent_type: str) -> None:\n",
 "        super().__init__(\"An agent with tools\")\n",
-"        self._system_messages: List[LLMMessage] = [SystemMessage(\"You are a helpful AI assistant.\")]\n",
+"        self._system_messages: List[LLMMessage] = [SystemMessage(content=\"You are a helpful AI assistant.\")]\n",
 "        self._model_client = model_client\n",
 "        self._tool_schema = tool_schema\n",
 "        self._tool_agent_id = AgentId(tool_agent_type, self.id.key)\n",
@ -36,7 +36,7 @@ Read the following conversation. Then select the next role from {participants} t

Read the above conversation. Then select the next role from {participants} to play. Only return the role.
"""
    select_speaker_messages = [SystemMessage(select_speaker_prompt)]
    select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
    response = await client.create(messages=select_speaker_messages)
    assert isinstance(response.content, str)
    mentions = await mentioned_agents(response.content, agents)
@ -92,7 +92,7 @@ def convert_messages_to_llm_messages(
                converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
                if converted_message_2 is not None:
                    result.append(converted_message_2)
            case FunctionExecutionResultMessage(_):
            case FunctionExecutionResultMessage(content=_):
                converted_message_3 = convert_tool_call_response_message(message, handle_unrepresentable)
                if converted_message_3 is not None:
                    result.append(converted_message_3)
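The pattern change above is forced by the same migration: stdlib dataclasses generate `__match_args__`, so a positional class pattern like `FunctionExecutionResultMessage(_)` used to work, but pydantic models do not define `__match_args__` by default, so sub-patterns must bind fields by keyword. A hedged sketch of the working form:

from autogen_core.components.models import (
    FunctionExecutionResultMessage,
    LLMMessage,
    UserMessage,
)


def describe(message: LLMMessage) -> str:
    match message:
        # Keyword patterns match on attributes, so they work for pydantic models:
        case UserMessage(content=content, source=source):
            return f"user message from {source}: {content!r}"
        case FunctionExecutionResultMessage(content=results):
            return f"{len(results)} function execution result(s)"
        case _:
            return "other message type"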
@ -31,7 +31,7 @@ class BaseGroupChatAgent(RoutedAgent):
        super().__init__(description=description)
        self._group_chat_topic_type = group_chat_topic_type
        self._model_client = model_client
        self._system_message = SystemMessage(system_message)
        self._system_message = SystemMessage(content=system_message)
        self._chat_history: List[LLMMessage] = []
        self._ui_config = ui_config
        self.console = Console()
@ -126,7 +126,7 @@ Read the following conversation. Then select the next role from {participants} t

Read the above conversation. Then select the next role from {participants} to play. if you think it's enough talking (for example they have talked for {self._max_rounds} rounds), return 'FINISH'.
"""
        system_message = SystemMessage(selector_prompt)
        system_message = SystemMessage(content=selector_prompt)
        completion = await self._model_client.create([system_message], cancellation_token=ctx.cancellation_token)

        assert isinstance(
@ -160,12 +160,14 @@ class SchedulingAssistantAgent(RoutedAgent):
        self._name = name
        self._model_client = model_client
        self._system_messages = [
            SystemMessage(f"""
            SystemMessage(
                content=f"""
I am a helpful AI assistant that helps schedule meetings.
If there are missing parameters, I will ask for them.

Today's date is {datetime.datetime.now().strftime("%Y-%m-%d")}
""")
"""
            )
        ]

    @message_handler
@ -116,10 +116,12 @@ class ClosureAgent(BaseAgent, ClosureContext):
        return await self._closure(self, message, ctx)

    async def save_state(self) -> Mapping[str, Any]:
        raise ValueError("save_state not implemented for ClosureAgent")
        """Closure agents do not have state. So this method always returns an empty dictionary."""
        return {}

    async def load_state(self, state: Mapping[str, Any]) -> None:
        raise ValueError("load_state not implemented for ClosureAgent")
        """Closure agents do not have state. So this method does nothing."""
        pass

    @classmethod
    async def register_closure(
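With this change, a runtime that contains closure agents can participate in whole-application save/load instead of raising: a closure agent simply contributes an empty state. A sketch of the contract, assuming `agent` is a registered `ClosureAgent` instance (surrounding runtime wiring omitted):

# Stateless round-trip: save_state yields an empty mapping and
# load_state accepts it as a no-op.
state = await agent.save_state()
assert state == {}
await agent.load_state(state)  # does nothing by design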
@ -1,42 +1,49 @@
from dataclasses import dataclass
from typing import List, Literal, Optional, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated

from ... import FunctionCall, Image


@dataclass
class SystemMessage:
class SystemMessage(BaseModel):
    content: str
    type: Literal["SystemMessage"] = "SystemMessage"


@dataclass
class UserMessage:
class UserMessage(BaseModel):
    content: Union[str, List[Union[str, Image]]]

    # Name of the agent that sent this message
    source: str

    type: Literal["UserMessage"] = "UserMessage"


@dataclass
class AssistantMessage:
class AssistantMessage(BaseModel):
    content: Union[str, List[FunctionCall]]

    # Name of the agent that sent this message
    source: str

    type: Literal["AssistantMessage"] = "AssistantMessage"


@dataclass
class FunctionExecutionResult:
class FunctionExecutionResult(BaseModel):
    content: str
    call_id: str


@dataclass
class FunctionExecutionResultMessage:
class FunctionExecutionResultMessage(BaseModel):
    content: List[FunctionExecutionResult]

    type: Literal["FunctionExecutionResultMessage"] = "FunctionExecutionResultMessage"


LLMMessage = Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage]
LLMMessage = Annotated[
    Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage], Field(discriminator="type")
]
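The `type` literal on each model plus the `Field(discriminator="type")` annotation is what makes `LLMMessage` round-trippable: on deserialization, pydantic reads the literal and rebuilds the correct concrete class. A sketch using pydantic v2's `TypeAdapter` (the adapter usage itself is not part of this diff):

from typing import List

from pydantic import TypeAdapter

adapter = TypeAdapter(List[LLMMessage])

messages: List[LLMMessage] = [
    SystemMessage(content="You are a helpful AI assistant."),
    UserMessage(content="What is 2 + 2?", source="user"),
]

payload = adapter.dump_python(messages)      # plain dicts, each carrying its "type" tag
restored = adapter.validate_python(payload)  # concrete classes reconstructed
assert isinstance(restored[0], SystemMessage)
assert isinstance(restored[1], UserMessage)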


@dataclass
@ -54,16 +61,14 @@ class TopLogprob:
    bytes: Optional[List[int]] = None


@dataclass
class ChatCompletionTokenLogprob:
class ChatCompletionTokenLogprob(BaseModel):
    token: str
    logprob: float
    top_logprobs: Optional[List[TopLogprob] | None] = None
    bytes: Optional[List[int]] = None


@dataclass
class CreateResult:
class CreateResult(BaseModel):
    finish_reason: FinishReasons
    content: Union[str, List[FunctionCall]]
    usage: RequestUsage
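Converting `CreateResult` to a BaseModel means completion results can be dumped to plain data for logging or caching without custom encoders. A sketch under the assumption that the remaining fields match the current autogen_core definitions (only the first three fields are visible in the hunk above):

from autogen_core.components.models import CreateResult, RequestUsage

result = CreateResult(
    finish_reason="stop",  # one of the FinishReasons literals
    content="4",
    usage=RequestUsage(prompt_tokens=10, completion_tokens=1),  # assumed field names
    cached=False,  # assumed field; not shown in the hunk above
)
print(result.model_dump())  # plain dict, ready to log or persist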
@ -36,9 +36,11 @@ class FileSurfer(BaseChatAgent):
    DEFAULT_DESCRIPTION = "An agent that can handle local files."

    DEFAULT_SYSTEM_MESSAGES = [
        SystemMessage("""
        SystemMessage(
            content="""
        You are a helpful AI Assistant.
        When given a user query, use available functions to help the user with their request."""),
        When given a user query, use available functions to help the user with their request."""
        ),
    ]

    def __init__(
@ -78,7 +80,7 @@ class FileSurfer(BaseChatAgent):

        except BaseException:
            content = f"File surfing error:\n\n{traceback.format_exc()}"
            self._chat_history.append(AssistantMessage(content, source=self.name))
            self._chat_history.append(AssistantMessage(content=content, source=self.name))
            return Response(chat_message=TextMessage(content=content, source=self.name))

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
@ -190,7 +190,7 @@ class MultimodalWebSurfer(BaseChatAgent):

        except BaseException:
            content = f"Web surfing error:\n\n{traceback.format_exc()}"
            self._chat_history.append(AssistantMessage(content, source=self.name))
            self._chat_history.append(AssistantMessage(content=content, source=self.name))
            return Response(chat_message=TextMessage(content=content, source=self.name))

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
@ -712,7 +712,7 @@ class MultimodalWebSurfer(BaseChatAgent):
        for line in page_markdown.splitlines():
            message = UserMessage(
                # content=[
                prompt + buffer + line,
                content=prompt + buffer + line,
                # ag_image,
                # ],
                source=self.name,
@ -33,7 +33,7 @@ class LLMAgent(RoutedAgent):

    @property
    def _fixed_message_history_type(self) -> List[SystemMessage]:
        return [SystemMessage(msg.content) for msg in self._chat_history]
        return [SystemMessage(content=msg.content) for msg in self._chat_history]


@default_subscription
@ -21,7 +21,8 @@ class Coder(BaseWorker):
    DEFAULT_DESCRIPTION = "A helpful and general-purpose AI assistant that has strong language skills, Python skills, and Linux command line skills."

    DEFAULT_SYSTEM_MESSAGES = [
        SystemMessage("""You are a helpful AI assistant.
        SystemMessage(
            content="""You are a helpful AI assistant.
Solve tasks using your coding and language skills.
In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.
1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.
@ -31,7 +32,8 @@ When using code, you must indicate the script type in the code block. The user c
If you want the user to save the code in a file before executing it, put # filename: <filename> inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.
If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.
When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
Reply "TERMINATE" in the end when everything is done.""")
Reply "TERMINATE" in the end when everything is done."""
        )
    ]

    def __init__(
@ -23,9 +23,11 @@ class FileSurfer(BaseWorker):
    DEFAULT_DESCRIPTION = "An agent that can handle local files."

    DEFAULT_SYSTEM_MESSAGES = [
        SystemMessage("""
        SystemMessage(
            content="""
        You are a helpful AI Assistant.
        When given a user query, use available functions to help the user with their request."""),
        When given a user query, use available functions to help the user with their request."""
        ),
    ]

    def __init__(
@ -817,7 +817,7 @@ When deciding between tools, consider if the request can be best addressed by:
        for line in re.split(r"([\r\n]+)", page_markdown):
            message = UserMessage(
                # content=[
                prompt + buffer + line,
                content=prompt + buffer + line,
                # ag_image,
                # ],
                source=self.metadata["type"],
@ -47,7 +47,7 @@ class LedgerOrchestrator(BaseOrchestrator):
    It uses a ledger (implemented as a JSON generated by the LLM) to keep track of task progress and select the next agent that should speak."""

    DEFAULT_SYSTEM_MESSAGES = [
        SystemMessage(ORCHESTRATOR_SYSTEM_MESSAGE),
        SystemMessage(content=ORCHESTRATOR_SYSTEM_MESSAGE),
    ]

    def __init__(