mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-13 16:44:32 +00:00
Load and Save state in AgentChat (#4436)
1. convert dataclass types to pydantic basemodel 2. add save_state and load_state for ChatAgent 3. state types for AgentChat --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
fef06fdc8a
commit
777f2abbd7
@ -127,7 +127,7 @@ async def main() -> None:
|
|||||||
lambda: Coder(
|
lambda: Coder(
|
||||||
model_client=client,
|
model_client=client,
|
||||||
system_messages=[
|
system_messages=[
|
||||||
SystemMessage("""You are a general-purpose AI assistant and can handle many questions -- but you don't have access to a web browser. However, the user you are talking to does have a browser, and you can see the screen. Provide short direct instructions to them to take you where you need to go to answer the initial question posed to you.
|
SystemMessage(content="""You are a general-purpose AI assistant and can handle many questions -- but you don't have access to a web browser. However, the user you are talking to does have a browser, and you can see the screen. Provide short direct instructions to them to take you where you need to go to answer the initial question posed to you.
|
||||||
|
|
||||||
Once the user has taken the final necessary action to complete the task, and you have fully addressed the initial request, reply with the word TERMINATE.""",
|
Once the user has taken the final necessary action to complete the task, and you have fully addressed the initial request, reply with the word TERMINATE.""",
|
||||||
)
|
)
|
||||||
|
|||||||
@ -2,7 +2,7 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Sequence
|
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Mapping, Sequence
|
||||||
|
|
||||||
from autogen_core import CancellationToken, FunctionCall
|
from autogen_core import CancellationToken, FunctionCall
|
||||||
from autogen_core.components.models import (
|
from autogen_core.components.models import (
|
||||||
@ -29,6 +29,7 @@ from ..messages import (
|
|||||||
ToolCallMessage,
|
ToolCallMessage,
|
||||||
ToolCallResultMessage,
|
ToolCallResultMessage,
|
||||||
)
|
)
|
||||||
|
from ..state import AssistantAgentState
|
||||||
from ._base_chat_agent import BaseChatAgent
|
from ._base_chat_agent import BaseChatAgent
|
||||||
|
|
||||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||||
@ -49,6 +50,12 @@ class Handoff(HandoffBase):
|
|||||||
class AssistantAgent(BaseChatAgent):
|
class AssistantAgent(BaseChatAgent):
|
||||||
"""An agent that provides assistance with tool use.
|
"""An agent that provides assistance with tool use.
|
||||||
|
|
||||||
|
```{note}
|
||||||
|
The assistant agent is not thread-safe or coroutine-safe.
|
||||||
|
It should not be shared between multiple tasks or coroutines, and it should
|
||||||
|
not call its methods concurrently.
|
||||||
|
```
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name (str): The name of the agent.
|
name (str): The name of the agent.
|
||||||
model_client (ChatCompletionClient): The model client to use for inference.
|
model_client (ChatCompletionClient): The model client to use for inference.
|
||||||
@ -224,6 +231,7 @@ class AssistantAgent(BaseChatAgent):
|
|||||||
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
|
f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}"
|
||||||
)
|
)
|
||||||
self._model_context: List[LLMMessage] = []
|
self._model_context: List[LLMMessage] = []
|
||||||
|
self._is_running = False
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def produced_message_types(self) -> List[type[ChatMessage]]:
|
def produced_message_types(self) -> List[type[ChatMessage]]:
|
||||||
@ -327,3 +335,13 @@ class AssistantAgent(BaseChatAgent):
|
|||||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||||
"""Reset the assistant agent to its initialization state."""
|
"""Reset the assistant agent to its initialization state."""
|
||||||
self._model_context.clear()
|
self._model_context.clear()
|
||||||
|
|
||||||
|
async def save_state(self) -> Mapping[str, Any]:
|
||||||
|
"""Save the current state of the assistant agent."""
|
||||||
|
return AssistantAgentState(llm_messages=self._model_context.copy()).model_dump()
|
||||||
|
|
||||||
|
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||||
|
"""Load the state of the assistant agent"""
|
||||||
|
assistant_agent_state = AssistantAgentState.model_validate(state)
|
||||||
|
self._model_context.clear()
|
||||||
|
self._model_context.extend(assistant_agent_state.llm_messages)
|
||||||
|
|||||||
@ -1,10 +1,11 @@
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import AsyncGenerator, List, Sequence
|
from typing import Any, AsyncGenerator, List, Mapping, Sequence
|
||||||
|
|
||||||
from autogen_core import CancellationToken
|
from autogen_core import CancellationToken
|
||||||
|
|
||||||
from ..base import ChatAgent, Response, TaskResult
|
from ..base import ChatAgent, Response, TaskResult
|
||||||
from ..messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
|
from ..messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
|
||||||
|
from ..state import BaseState
|
||||||
|
|
||||||
|
|
||||||
class BaseChatAgent(ChatAgent, ABC):
|
class BaseChatAgent(ChatAgent, ABC):
|
||||||
@ -117,3 +118,11 @@ class BaseChatAgent(ChatAgent, ABC):
|
|||||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||||
"""Resets the agent to its initialization state."""
|
"""Resets the agent to its initialization state."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
async def save_state(self) -> Mapping[str, Any]:
|
||||||
|
"""Export state. Default implementation for stateless agents."""
|
||||||
|
return BaseState().model_dump()
|
||||||
|
|
||||||
|
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||||
|
"""Restore agent from saved state. Default implementation for stateless agents."""
|
||||||
|
BaseState.model_validate(state)
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import AsyncGenerator, List, Protocol, Sequence, runtime_checkable
|
from typing import Any, AsyncGenerator, List, Mapping, Protocol, Sequence, runtime_checkable
|
||||||
|
|
||||||
from autogen_core import CancellationToken
|
from autogen_core import CancellationToken
|
||||||
|
|
||||||
@ -54,3 +54,11 @@ class ChatAgent(TaskRunner, Protocol):
|
|||||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||||
"""Resets the agent to its initialization state."""
|
"""Resets the agent to its initialization state."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
async def save_state(self) -> Mapping[str, Any]:
|
||||||
|
"""Save agent state for later restoration"""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||||
|
"""Restore agent from saved state"""
|
||||||
|
...
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Protocol
|
from typing import Any, Mapping, Protocol
|
||||||
|
|
||||||
from ._task import TaskRunner
|
from ._task import TaskRunner
|
||||||
|
|
||||||
@ -7,3 +7,11 @@ class Team(TaskRunner, Protocol):
|
|||||||
async def reset(self) -> None:
|
async def reset(self) -> None:
|
||||||
"""Reset the team and all its participants to its initial state."""
|
"""Reset the team and all its participants to its initial state."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
async def save_state(self) -> Mapping[str, Any]:
|
||||||
|
"""Save the current state of the team."""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||||
|
"""Load the state of the team."""
|
||||||
|
...
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
from typing import List
|
from typing import List, Literal
|
||||||
|
|
||||||
from autogen_core import FunctionCall, Image
|
from autogen_core import FunctionCall, Image
|
||||||
from autogen_core.components.models import FunctionExecutionResult, RequestUsage
|
from autogen_core.components.models import FunctionExecutionResult, RequestUsage
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
|
||||||
class BaseMessage(BaseModel):
|
class BaseMessage(BaseModel):
|
||||||
@ -23,6 +24,8 @@ class TextMessage(BaseMessage):
|
|||||||
content: str
|
content: str
|
||||||
"""The content of the message."""
|
"""The content of the message."""
|
||||||
|
|
||||||
|
type: Literal["TextMessage"] = "TextMessage"
|
||||||
|
|
||||||
|
|
||||||
class MultiModalMessage(BaseMessage):
|
class MultiModalMessage(BaseMessage):
|
||||||
"""A multimodal message."""
|
"""A multimodal message."""
|
||||||
@ -30,6 +33,8 @@ class MultiModalMessage(BaseMessage):
|
|||||||
content: List[str | Image]
|
content: List[str | Image]
|
||||||
"""The content of the message."""
|
"""The content of the message."""
|
||||||
|
|
||||||
|
type: Literal["MultiModalMessage"] = "MultiModalMessage"
|
||||||
|
|
||||||
|
|
||||||
class StopMessage(BaseMessage):
|
class StopMessage(BaseMessage):
|
||||||
"""A message requesting stop of a conversation."""
|
"""A message requesting stop of a conversation."""
|
||||||
@ -37,6 +42,8 @@ class StopMessage(BaseMessage):
|
|||||||
content: str
|
content: str
|
||||||
"""The content for the stop message."""
|
"""The content for the stop message."""
|
||||||
|
|
||||||
|
type: Literal["StopMessage"] = "StopMessage"
|
||||||
|
|
||||||
|
|
||||||
class HandoffMessage(BaseMessage):
|
class HandoffMessage(BaseMessage):
|
||||||
"""A message requesting handoff of a conversation to another agent."""
|
"""A message requesting handoff of a conversation to another agent."""
|
||||||
@ -47,6 +54,8 @@ class HandoffMessage(BaseMessage):
|
|||||||
content: str
|
content: str
|
||||||
"""The handoff message to the target agent."""
|
"""The handoff message to the target agent."""
|
||||||
|
|
||||||
|
type: Literal["HandoffMessage"] = "HandoffMessage"
|
||||||
|
|
||||||
|
|
||||||
class ToolCallMessage(BaseMessage):
|
class ToolCallMessage(BaseMessage):
|
||||||
"""A message signaling the use of tools."""
|
"""A message signaling the use of tools."""
|
||||||
@ -54,6 +63,8 @@ class ToolCallMessage(BaseMessage):
|
|||||||
content: List[FunctionCall]
|
content: List[FunctionCall]
|
||||||
"""The tool calls."""
|
"""The tool calls."""
|
||||||
|
|
||||||
|
type: Literal["ToolCallMessage"] = "ToolCallMessage"
|
||||||
|
|
||||||
|
|
||||||
class ToolCallResultMessage(BaseMessage):
|
class ToolCallResultMessage(BaseMessage):
|
||||||
"""A message signaling the results of tool calls."""
|
"""A message signaling the results of tool calls."""
|
||||||
@ -61,12 +72,17 @@ class ToolCallResultMessage(BaseMessage):
|
|||||||
content: List[FunctionExecutionResult]
|
content: List[FunctionExecutionResult]
|
||||||
"""The tool call results."""
|
"""The tool call results."""
|
||||||
|
|
||||||
|
type: Literal["ToolCallResultMessage"] = "ToolCallResultMessage"
|
||||||
|
|
||||||
ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage
|
|
||||||
|
ChatMessage = Annotated[TextMessage | MultiModalMessage | StopMessage | HandoffMessage, Field(discriminator="type")]
|
||||||
"""Messages for agent-to-agent communication."""
|
"""Messages for agent-to-agent communication."""
|
||||||
|
|
||||||
|
|
||||||
AgentMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallMessage | ToolCallResultMessage
|
AgentMessage = Annotated[
|
||||||
|
TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ToolCallMessage | ToolCallResultMessage,
|
||||||
|
Field(discriminator="type"),
|
||||||
|
]
|
||||||
"""All message types."""
|
"""All message types."""
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -0,0 +1,25 @@
|
|||||||
|
"""State management for agents, teams and termination conditions."""
|
||||||
|
|
||||||
|
from ._states import (
|
||||||
|
AssistantAgentState,
|
||||||
|
BaseGroupChatManagerState,
|
||||||
|
BaseState,
|
||||||
|
ChatAgentContainerState,
|
||||||
|
MagenticOneOrchestratorState,
|
||||||
|
RoundRobinManagerState,
|
||||||
|
SelectorManagerState,
|
||||||
|
SwarmManagerState,
|
||||||
|
TeamState,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseState",
|
||||||
|
"AssistantAgentState",
|
||||||
|
"BaseGroupChatManagerState",
|
||||||
|
"ChatAgentContainerState",
|
||||||
|
"RoundRobinManagerState",
|
||||||
|
"SelectorManagerState",
|
||||||
|
"SwarmManagerState",
|
||||||
|
"MagenticOneOrchestratorState",
|
||||||
|
"TeamState",
|
||||||
|
]
|
||||||
@ -0,0 +1,81 @@
|
|||||||
|
from typing import Any, List, Mapping, Optional
|
||||||
|
|
||||||
|
from autogen_core.components.models import (
|
||||||
|
LLMMessage,
|
||||||
|
)
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from ..messages import (
|
||||||
|
AgentMessage,
|
||||||
|
ChatMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseState(BaseModel):
|
||||||
|
"""Base class for all saveable state"""
|
||||||
|
|
||||||
|
type: str = Field(default="BaseState")
|
||||||
|
version: str = Field(default="1.0.0")
|
||||||
|
|
||||||
|
|
||||||
|
class AssistantAgentState(BaseState):
|
||||||
|
"""State for an assistant agent."""
|
||||||
|
|
||||||
|
llm_messages: List[LLMMessage] = Field(default_factory=list)
|
||||||
|
type: str = Field(default="AssistantAgentState")
|
||||||
|
|
||||||
|
|
||||||
|
class TeamState(BaseState):
|
||||||
|
"""State for a team of agents."""
|
||||||
|
|
||||||
|
agent_states: Mapping[str, Any] = Field(default_factory=dict)
|
||||||
|
team_id: str = Field(default="")
|
||||||
|
type: str = Field(default="TeamState")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseGroupChatManagerState(BaseState):
|
||||||
|
"""Base state for all group chat managers."""
|
||||||
|
|
||||||
|
message_thread: List[AgentMessage] = Field(default_factory=list)
|
||||||
|
current_turn: int = Field(default=0)
|
||||||
|
type: str = Field(default="BaseGroupChatManagerState")
|
||||||
|
|
||||||
|
|
||||||
|
class ChatAgentContainerState(BaseState):
|
||||||
|
"""State for a container of chat agents."""
|
||||||
|
|
||||||
|
agent_state: Mapping[str, Any] = Field(default_factory=dict)
|
||||||
|
message_buffer: List[ChatMessage] = Field(default_factory=list)
|
||||||
|
type: str = Field(default="ChatAgentContainerState")
|
||||||
|
|
||||||
|
|
||||||
|
class RoundRobinManagerState(BaseGroupChatManagerState):
|
||||||
|
"""State for :class:`~autogen_agentchat.teams.RoundRobinGroupChat` manager."""
|
||||||
|
|
||||||
|
next_speaker_index: int = Field(default=0)
|
||||||
|
type: str = Field(default="RoundRobinManagerState")
|
||||||
|
|
||||||
|
|
||||||
|
class SelectorManagerState(BaseGroupChatManagerState):
|
||||||
|
"""State for :class:`~autogen_agentchat.teams.SelectorGroupChat` manager."""
|
||||||
|
|
||||||
|
previous_speaker: Optional[str] = Field(default=None)
|
||||||
|
type: str = Field(default="SelectorManagerState")
|
||||||
|
|
||||||
|
|
||||||
|
class SwarmManagerState(BaseGroupChatManagerState):
|
||||||
|
"""State for :class:`~autogen_agentchat.teams.Swarm` manager."""
|
||||||
|
|
||||||
|
current_speaker: str = Field(default="")
|
||||||
|
type: str = Field(default="SwarmManagerState")
|
||||||
|
|
||||||
|
|
||||||
|
class MagenticOneOrchestratorState(BaseGroupChatManagerState):
|
||||||
|
"""State for :class:`~autogen_agentchat.teams.MagneticOneGroupChat` orchestrator."""
|
||||||
|
|
||||||
|
task: str = Field(default="")
|
||||||
|
facts: str = Field(default="")
|
||||||
|
plan: str = Field(default="")
|
||||||
|
n_rounds: int = Field(default=0)
|
||||||
|
n_stalls: int = Field(default=0)
|
||||||
|
type: str = Field(default="MagenticOneOrchestratorState")
|
||||||
@ -2,7 +2,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import AsyncGenerator, Callable, List
|
from typing import Any, AsyncGenerator, Callable, List, Mapping
|
||||||
|
|
||||||
from autogen_core import (
|
from autogen_core import (
|
||||||
AgentId,
|
AgentId,
|
||||||
@ -20,6 +20,7 @@ from autogen_core.application import SingleThreadedAgentRuntime
|
|||||||
from ... import EVENT_LOGGER_NAME
|
from ... import EVENT_LOGGER_NAME
|
||||||
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
|
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
|
||||||
from ...messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
|
from ...messages import AgentMessage, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage, TextMessage
|
||||||
|
from ...state import TeamState
|
||||||
from ._chat_agent_container import ChatAgentContainer
|
from ._chat_agent_container import ChatAgentContainer
|
||||||
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
|
from ._events import GroupChatMessage, GroupChatReset, GroupChatStart, GroupChatTermination
|
||||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||||
@ -493,3 +494,38 @@ class BaseGroupChat(Team, ABC):
|
|||||||
|
|
||||||
# Indicate that the team is no longer running.
|
# Indicate that the team is no longer running.
|
||||||
self._is_running = False
|
self._is_running = False
|
||||||
|
|
||||||
|
async def save_state(self) -> Mapping[str, Any]:
|
||||||
|
"""Save the state of the group chat team."""
|
||||||
|
if not self._initialized:
|
||||||
|
raise RuntimeError("The group chat has not been initialized. It must be run before it can be saved.")
|
||||||
|
|
||||||
|
if self._is_running:
|
||||||
|
raise RuntimeError("The team cannot be saved while it is running.")
|
||||||
|
self._is_running = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Save the state of the runtime. This will save the state of the participants and the group chat manager.
|
||||||
|
agent_states = await self._runtime.save_state()
|
||||||
|
return TeamState(agent_states=agent_states, team_id=self._team_id).model_dump()
|
||||||
|
finally:
|
||||||
|
# Indicate that the team is no longer running.
|
||||||
|
self._is_running = False
|
||||||
|
|
||||||
|
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||||
|
"""Load the state of the group chat team."""
|
||||||
|
if not self._initialized:
|
||||||
|
await self._init(self._runtime)
|
||||||
|
|
||||||
|
if self._is_running:
|
||||||
|
raise RuntimeError("The team cannot be loaded while it is running.")
|
||||||
|
self._is_running = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Load the state of the runtime. This will load the state of the participants and the group chat manager.
|
||||||
|
team_state = TeamState.model_validate(state)
|
||||||
|
self._team_id = team_state.team_id
|
||||||
|
await self._runtime.load_state(team_state.agent_states)
|
||||||
|
finally:
|
||||||
|
# Indicate that the team is no longer running.
|
||||||
|
self._is_running = False
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
from typing import Any, List
|
from typing import Any, List, Mapping
|
||||||
|
|
||||||
from autogen_core import DefaultTopicId, MessageContext, event, rpc
|
from autogen_core import DefaultTopicId, MessageContext, event, rpc
|
||||||
|
|
||||||
from ...base import ChatAgent, Response
|
from ...base import ChatAgent, Response
|
||||||
from ...messages import ChatMessage
|
from ...messages import ChatMessage
|
||||||
|
from ...state import ChatAgentContainerState
|
||||||
from ._events import GroupChatAgentResponse, GroupChatMessage, GroupChatRequestPublish, GroupChatReset, GroupChatStart
|
from ._events import GroupChatAgentResponse, GroupChatMessage, GroupChatRequestPublish, GroupChatReset, GroupChatStart
|
||||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||||
|
|
||||||
@ -75,3 +76,13 @@ class ChatAgentContainer(SequentialRoutedAgent):
|
|||||||
|
|
||||||
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
|
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
|
||||||
raise ValueError(f"Unhandled message in agent container: {type(message)}")
|
raise ValueError(f"Unhandled message in agent container: {type(message)}")
|
||||||
|
|
||||||
|
async def save_state(self) -> Mapping[str, Any]:
|
||||||
|
agent_state = await self._agent.save_state()
|
||||||
|
state = ChatAgentContainerState(agent_state=agent_state, message_buffer=list(self._message_buffer))
|
||||||
|
return state.model_dump()
|
||||||
|
|
||||||
|
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||||
|
container_state = ChatAgentContainerState.model_validate(state)
|
||||||
|
self._message_buffer = list(container_state.message_buffer)
|
||||||
|
await self._agent.load_state(container_state.agent_state)
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Mapping
|
||||||
|
|
||||||
from autogen_core import AgentId, CancellationToken, DefaultTopicId, Image, MessageContext, event, rpc
|
from autogen_core import AgentId, CancellationToken, DefaultTopicId, Image, MessageContext, event, rpc
|
||||||
from autogen_core.components.models import (
|
from autogen_core.components.models import (
|
||||||
@ -13,6 +13,7 @@ from autogen_core.components.models import (
|
|||||||
from .... import TRACE_LOGGER_NAME
|
from .... import TRACE_LOGGER_NAME
|
||||||
from ....base import Response, TerminationCondition
|
from ....base import Response, TerminationCondition
|
||||||
from ....messages import AgentMessage, ChatMessage, MultiModalMessage, StopMessage, TextMessage
|
from ....messages import AgentMessage, ChatMessage, MultiModalMessage, StopMessage, TextMessage
|
||||||
|
from ....state import MagenticOneOrchestratorState
|
||||||
from .._base_group_chat_manager import BaseGroupChatManager
|
from .._base_group_chat_manager import BaseGroupChatManager
|
||||||
from .._events import (
|
from .._events import (
|
||||||
GroupChatAgentResponse,
|
GroupChatAgentResponse,
|
||||||
@ -178,6 +179,28 @@ class MagenticOneOrchestrator(BaseGroupChatManager):
|
|||||||
async def validate_group_state(self, message: ChatMessage | None) -> None:
|
async def validate_group_state(self, message: ChatMessage | None) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def save_state(self) -> Mapping[str, Any]:
|
||||||
|
state = MagenticOneOrchestratorState(
|
||||||
|
message_thread=list(self._message_thread),
|
||||||
|
current_turn=self._current_turn,
|
||||||
|
task=self._task,
|
||||||
|
facts=self._facts,
|
||||||
|
plan=self._plan,
|
||||||
|
n_rounds=self._n_rounds,
|
||||||
|
n_stalls=self._n_stalls,
|
||||||
|
)
|
||||||
|
return state.model_dump()
|
||||||
|
|
||||||
|
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||||
|
orchestrator_state = MagenticOneOrchestratorState.model_validate(state)
|
||||||
|
self._message_thread = orchestrator_state.message_thread
|
||||||
|
self._current_turn = orchestrator_state.current_turn
|
||||||
|
self._task = orchestrator_state.task
|
||||||
|
self._facts = orchestrator_state.facts
|
||||||
|
self._plan = orchestrator_state.plan
|
||||||
|
self._n_rounds = orchestrator_state.n_rounds
|
||||||
|
self._n_stalls = orchestrator_state.n_stalls
|
||||||
|
|
||||||
async def select_speaker(self, thread: List[AgentMessage]) -> str:
|
async def select_speaker(self, thread: List[AgentMessage]) -> str:
|
||||||
"""Not used in this orchestrator, we select next speaker in _orchestrate_step."""
|
"""Not used in this orchestrator, we select next speaker in _orchestrate_step."""
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
from typing import Callable, List
|
from typing import Any, Callable, List, Mapping
|
||||||
|
|
||||||
from ...base import ChatAgent, TerminationCondition
|
from ...base import ChatAgent, TerminationCondition
|
||||||
from ...messages import AgentMessage, ChatMessage
|
from ...messages import AgentMessage, ChatMessage
|
||||||
|
from ...state import RoundRobinManagerState
|
||||||
from ._base_group_chat import BaseGroupChat
|
from ._base_group_chat import BaseGroupChat
|
||||||
from ._base_group_chat_manager import BaseGroupChatManager
|
from ._base_group_chat_manager import BaseGroupChatManager
|
||||||
|
|
||||||
@ -38,6 +39,20 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
|
|||||||
await self._termination_condition.reset()
|
await self._termination_condition.reset()
|
||||||
self._next_speaker_index = 0
|
self._next_speaker_index = 0
|
||||||
|
|
||||||
|
async def save_state(self) -> Mapping[str, Any]:
|
||||||
|
state = RoundRobinManagerState(
|
||||||
|
message_thread=list(self._message_thread),
|
||||||
|
current_turn=self._current_turn,
|
||||||
|
next_speaker_index=self._next_speaker_index,
|
||||||
|
)
|
||||||
|
return state.model_dump()
|
||||||
|
|
||||||
|
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||||
|
round_robin_state = RoundRobinManagerState.model_validate(state)
|
||||||
|
self._message_thread = list(round_robin_state.message_thread)
|
||||||
|
self._current_turn = round_robin_state.current_turn
|
||||||
|
self._next_speaker_index = round_robin_state.next_speaker_index
|
||||||
|
|
||||||
async def select_speaker(self, thread: List[AgentMessage]) -> str:
|
async def select_speaker(self, thread: List[AgentMessage]) -> str:
|
||||||
"""Select a speaker from the participants in a round-robin fashion."""
|
"""Select a speaker from the participants in a round-robin fashion."""
|
||||||
current_speaker_index = self._next_speaker_index
|
current_speaker_index = self._next_speaker_index
|
||||||
|
|||||||
@ -1,10 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Callable, Dict, List, Sequence
|
from typing import Any, Callable, Dict, List, Mapping, Sequence
|
||||||
|
|
||||||
from autogen_core.components.models import ChatCompletionClient, SystemMessage
|
from autogen_core.components.models import ChatCompletionClient, SystemMessage
|
||||||
|
|
||||||
from ... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
|
from ... import TRACE_LOGGER_NAME
|
||||||
from ...base import ChatAgent, TerminationCondition
|
from ...base import ChatAgent, TerminationCondition
|
||||||
from ...messages import (
|
from ...messages import (
|
||||||
AgentMessage,
|
AgentMessage,
|
||||||
@ -16,11 +16,11 @@ from ...messages import (
|
|||||||
ToolCallMessage,
|
ToolCallMessage,
|
||||||
ToolCallResultMessage,
|
ToolCallResultMessage,
|
||||||
)
|
)
|
||||||
|
from ...state import SelectorManagerState
|
||||||
from ._base_group_chat import BaseGroupChat
|
from ._base_group_chat import BaseGroupChat
|
||||||
from ._base_group_chat_manager import BaseGroupChatManager
|
from ._base_group_chat_manager import BaseGroupChatManager
|
||||||
|
|
||||||
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
|
trace_logger = logging.getLogger(TRACE_LOGGER_NAME)
|
||||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
|
||||||
|
|
||||||
|
|
||||||
class SelectorGroupChatManager(BaseGroupChatManager):
|
class SelectorGroupChatManager(BaseGroupChatManager):
|
||||||
@ -64,6 +64,20 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
|||||||
await self._termination_condition.reset()
|
await self._termination_condition.reset()
|
||||||
self._previous_speaker = None
|
self._previous_speaker = None
|
||||||
|
|
||||||
|
async def save_state(self) -> Mapping[str, Any]:
|
||||||
|
state = SelectorManagerState(
|
||||||
|
message_thread=list(self._message_thread),
|
||||||
|
current_turn=self._current_turn,
|
||||||
|
previous_speaker=self._previous_speaker,
|
||||||
|
)
|
||||||
|
return state.model_dump()
|
||||||
|
|
||||||
|
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||||
|
selector_state = SelectorManagerState.model_validate(state)
|
||||||
|
self._message_thread = list(selector_state.message_thread)
|
||||||
|
self._current_turn = selector_state.current_turn
|
||||||
|
self._previous_speaker = selector_state.previous_speaker
|
||||||
|
|
||||||
async def select_speaker(self, thread: List[AgentMessage]) -> str:
|
async def select_speaker(self, thread: List[AgentMessage]) -> str:
|
||||||
"""Selects the next speaker in a group chat using a ChatCompletion client,
|
"""Selects the next speaker in a group chat using a ChatCompletion client,
|
||||||
with the selector function as override if it returns a speaker name.
|
with the selector function as override if it returns a speaker name.
|
||||||
@ -121,7 +135,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
|||||||
select_speaker_prompt = self._selector_prompt.format(
|
select_speaker_prompt = self._selector_prompt.format(
|
||||||
roles=roles, participants=str(participants), history=history
|
roles=roles, participants=str(participants), history=history
|
||||||
)
|
)
|
||||||
select_speaker_messages = [SystemMessage(select_speaker_prompt)]
|
select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
|
||||||
response = await self._model_client.create(messages=select_speaker_messages)
|
response = await self._model_client.create(messages=select_speaker_messages)
|
||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
mentions = self._mentioned_agents(response.content, self._participant_topic_types)
|
mentions = self._mentioned_agents(response.content, self._participant_topic_types)
|
||||||
|
|||||||
@ -1,14 +1,11 @@
|
|||||||
import logging
|
from typing import Any, Callable, List, Mapping
|
||||||
from typing import Callable, List
|
|
||||||
|
|
||||||
from ... import EVENT_LOGGER_NAME
|
|
||||||
from ...base import ChatAgent, TerminationCondition
|
from ...base import ChatAgent, TerminationCondition
|
||||||
from ...messages import AgentMessage, ChatMessage, HandoffMessage
|
from ...messages import AgentMessage, ChatMessage, HandoffMessage
|
||||||
|
from ...state import SwarmManagerState
|
||||||
from ._base_group_chat import BaseGroupChat
|
from ._base_group_chat import BaseGroupChat
|
||||||
from ._base_group_chat_manager import BaseGroupChatManager
|
from ._base_group_chat_manager import BaseGroupChatManager
|
||||||
|
|
||||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
|
||||||
|
|
||||||
|
|
||||||
class SwarmGroupChatManager(BaseGroupChatManager):
|
class SwarmGroupChatManager(BaseGroupChatManager):
|
||||||
"""A group chat manager that selects the next speaker based on handoff message only."""
|
"""A group chat manager that selects the next speaker based on handoff message only."""
|
||||||
@ -77,6 +74,20 @@ class SwarmGroupChatManager(BaseGroupChatManager):
|
|||||||
return self._current_speaker
|
return self._current_speaker
|
||||||
return self._current_speaker
|
return self._current_speaker
|
||||||
|
|
||||||
|
async def save_state(self) -> Mapping[str, Any]:
|
||||||
|
state = SwarmManagerState(
|
||||||
|
message_thread=list(self._message_thread),
|
||||||
|
current_turn=self._current_turn,
|
||||||
|
current_speaker=self._current_speaker,
|
||||||
|
)
|
||||||
|
return state.model_dump()
|
||||||
|
|
||||||
|
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||||
|
swarm_state = SwarmManagerState.model_validate(state)
|
||||||
|
self._message_thread = list(swarm_state.message_thread)
|
||||||
|
self._current_turn = swarm_state.current_turn
|
||||||
|
self._current_speaker = swarm_state.current_speaker
|
||||||
|
|
||||||
|
|
||||||
class Swarm(BaseGroupChat):
|
class Swarm(BaseGroupChat):
|
||||||
"""A group chat team that selects the next speaker based on handoff message only.
|
"""A group chat team that selects the next speaker based on handoff message only.
|
||||||
|
|||||||
@ -112,12 +112,12 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
]
|
]
|
||||||
mock = _MockChatCompletion(chat_completions)
|
mock = _MockChatCompletion(chat_completions)
|
||||||
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
|
||||||
tool_use_agent = AssistantAgent(
|
agent = AssistantAgent(
|
||||||
"tool_use_agent",
|
"tool_use_agent",
|
||||||
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||||
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
||||||
)
|
)
|
||||||
result = await tool_use_agent.run(task="task")
|
result = await agent.run(task="task")
|
||||||
assert len(result.messages) == 4
|
assert len(result.messages) == 4
|
||||||
assert isinstance(result.messages[0], TextMessage)
|
assert isinstance(result.messages[0], TextMessage)
|
||||||
assert result.messages[0].models_usage is None
|
assert result.messages[0].models_usage is None
|
||||||
@ -135,13 +135,24 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
# Test streaming.
|
# Test streaming.
|
||||||
mock._curr_index = 0 # pyright: ignore
|
mock._curr_index = 0 # pyright: ignore
|
||||||
index = 0
|
index = 0
|
||||||
async for message in tool_use_agent.run_stream(task="task"):
|
async for message in agent.run_stream(task="task"):
|
||||||
if isinstance(message, TaskResult):
|
if isinstance(message, TaskResult):
|
||||||
assert message == result
|
assert message == result
|
||||||
else:
|
else:
|
||||||
assert message == result.messages[index]
|
assert message == result.messages[index]
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
|
# Test state saving and loading.
|
||||||
|
state = await agent.save_state()
|
||||||
|
agent2 = AssistantAgent(
|
||||||
|
"tool_use_agent",
|
||||||
|
model_client=OpenAIChatCompletionClient(model=model, api_key=""),
|
||||||
|
tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
|
||||||
|
)
|
||||||
|
await agent2.load_state(state)
|
||||||
|
state2 = await agent2.save_state()
|
||||||
|
assert state == state2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
|||||||
@ -28,12 +28,15 @@ from autogen_agentchat.teams import (
|
|||||||
SelectorGroupChat,
|
SelectorGroupChat,
|
||||||
Swarm,
|
Swarm,
|
||||||
)
|
)
|
||||||
|
from autogen_agentchat.teams._group_chat._round_robin_group_chat import RoundRobinGroupChatManager
|
||||||
|
from autogen_agentchat.teams._group_chat._selector_group_chat import SelectorGroupChatManager
|
||||||
|
from autogen_agentchat.teams._group_chat._swarm_group_chat import SwarmGroupChatManager
|
||||||
from autogen_agentchat.ui import Console
|
from autogen_agentchat.ui import Console
|
||||||
from autogen_core import CancellationToken, FunctionCall
|
from autogen_core import AgentId, CancellationToken, FunctionCall
|
||||||
from autogen_core.components.code_executor import LocalCommandLineCodeExecutor
|
from autogen_core.components.code_executor import LocalCommandLineCodeExecutor
|
||||||
from autogen_core.components.models import FunctionExecutionResult
|
from autogen_core.components.models import FunctionExecutionResult
|
||||||
from autogen_core.components.tools import FunctionTool
|
from autogen_core.components.tools import FunctionTool
|
||||||
from autogen_ext.models import OpenAIChatCompletionClient
|
from autogen_ext.models import OpenAIChatCompletionClient, ReplayChatCompletionClient
|
||||||
from openai.resources.chat.completions import AsyncCompletions
|
from openai.resources.chat.completions import AsyncCompletions
|
||||||
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||||
@ -217,6 +220,38 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
assert result.messages[1:] == result_2.messages[1:]
|
assert result.messages[1:] == result_2.messages[1:]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_round_robin_group_chat_state() -> None:
|
||||||
|
model_client = ReplayChatCompletionClient(
|
||||||
|
["No facts", "No plan", "print('Hello, world!')", "TERMINATE"],
|
||||||
|
)
|
||||||
|
agent1 = AssistantAgent("agent1", model_client=model_client)
|
||||||
|
agent2 = AssistantAgent("agent2", model_client=model_client)
|
||||||
|
termination = TextMentionTermination("TERMINATE")
|
||||||
|
team1 = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination)
|
||||||
|
await team1.run(task="Write a program that prints 'Hello, world!'")
|
||||||
|
state = await team1.save_state()
|
||||||
|
|
||||||
|
agent3 = AssistantAgent("agent1", model_client=model_client)
|
||||||
|
agent4 = AssistantAgent("agent2", model_client=model_client)
|
||||||
|
team2 = RoundRobinGroupChat(participants=[agent3, agent4], termination_condition=termination)
|
||||||
|
await team2.load_state(state)
|
||||||
|
state2 = await team2.save_state()
|
||||||
|
assert state == state2
|
||||||
|
assert agent3._model_context == agent1._model_context # pyright: ignore
|
||||||
|
assert agent4._model_context == agent2._model_context # pyright: ignore
|
||||||
|
manager_1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
||||||
|
AgentId("group_chat_manager", team1._team_id), # pyright: ignore
|
||||||
|
RoundRobinGroupChatManager, # pyright: ignore
|
||||||
|
) # pyright: ignore
|
||||||
|
manager_2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
||||||
|
AgentId("group_chat_manager", team2._team_id), # pyright: ignore
|
||||||
|
RoundRobinGroupChatManager, # pyright: ignore
|
||||||
|
) # pyright: ignore
|
||||||
|
assert manager_1._current_turn == manager_2._current_turn # pyright: ignore
|
||||||
|
assert manager_1._message_thread == manager_2._message_thread # pyright: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
model = "gpt-4o-2024-05-13"
|
model = "gpt-4o-2024-05-13"
|
||||||
@ -528,6 +563,42 @@ async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
assert result2 == result
|
assert result2 == result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_selector_group_chat_state() -> None:
|
||||||
|
model_client = ReplayChatCompletionClient(
|
||||||
|
["agent1", "No facts", "agent2", "No plan", "agent1", "print('Hello, world!')", "agent2", "TERMINATE"],
|
||||||
|
)
|
||||||
|
agent1 = AssistantAgent("agent1", model_client=model_client)
|
||||||
|
agent2 = AssistantAgent("agent2", model_client=model_client)
|
||||||
|
termination = TextMentionTermination("TERMINATE")
|
||||||
|
team1 = SelectorGroupChat(
|
||||||
|
participants=[agent1, agent2], termination_condition=termination, model_client=model_client
|
||||||
|
)
|
||||||
|
await team1.run(task="Write a program that prints 'Hello, world!'")
|
||||||
|
state = await team1.save_state()
|
||||||
|
|
||||||
|
agent3 = AssistantAgent("agent1", model_client=model_client)
|
||||||
|
agent4 = AssistantAgent("agent2", model_client=model_client)
|
||||||
|
team2 = SelectorGroupChat(
|
||||||
|
participants=[agent3, agent4], termination_condition=termination, model_client=model_client
|
||||||
|
)
|
||||||
|
await team2.load_state(state)
|
||||||
|
state2 = await team2.save_state()
|
||||||
|
assert state == state2
|
||||||
|
assert agent3._model_context == agent1._model_context # pyright: ignore
|
||||||
|
assert agent4._model_context == agent2._model_context # pyright: ignore
|
||||||
|
manager_1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
||||||
|
AgentId("group_chat_manager", team1._team_id), # pyright: ignore
|
||||||
|
SelectorGroupChatManager, # pyright: ignore
|
||||||
|
) # pyright: ignore
|
||||||
|
manager_2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
||||||
|
AgentId("group_chat_manager", team2._team_id), # pyright: ignore
|
||||||
|
SelectorGroupChatManager, # pyright: ignore
|
||||||
|
) # pyright: ignore
|
||||||
|
assert manager_1._message_thread == manager_2._message_thread # pyright: ignore
|
||||||
|
assert manager_1._previous_speaker == manager_2._previous_speaker # pyright: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
model = "gpt-4o-2024-05-13"
|
model = "gpt-4o-2024-05-13"
|
||||||
@ -768,6 +839,26 @@ async def test_swarm_handoff() -> None:
|
|||||||
assert message == result.messages[index]
|
assert message == result.messages[index]
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
|
# Test save and load.
|
||||||
|
state = await team.save_state()
|
||||||
|
first_agent2 = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
|
||||||
|
second_agent2 = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
|
||||||
|
third_agent2 = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
|
||||||
|
team2 = Swarm([second_agent2, first_agent2, third_agent2], termination_condition=termination)
|
||||||
|
await team2.load_state(state)
|
||||||
|
state2 = await team2.save_state()
|
||||||
|
assert state == state2
|
||||||
|
manager_1 = await team._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
||||||
|
AgentId("group_chat_manager", team._team_id), # pyright: ignore
|
||||||
|
SwarmGroupChatManager, # pyright: ignore
|
||||||
|
) # pyright: ignore
|
||||||
|
manager_2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
||||||
|
AgentId("group_chat_manager", team2._team_id), # pyright: ignore
|
||||||
|
SwarmGroupChatManager, # pyright: ignore
|
||||||
|
) # pyright: ignore
|
||||||
|
assert manager_1._message_thread == manager_2._message_thread # pyright: ignore
|
||||||
|
assert manager_1._current_speaker == manager_2._current_speaker # pyright: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
|
async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
|||||||
@ -16,7 +16,8 @@ from autogen_agentchat.messages import (
|
|||||||
from autogen_agentchat.teams import (
|
from autogen_agentchat.teams import (
|
||||||
MagenticOneGroupChat,
|
MagenticOneGroupChat,
|
||||||
)
|
)
|
||||||
from autogen_core import CancellationToken
|
from autogen_agentchat.teams._group_chat._magentic_one._magentic_one_orchestrator import MagenticOneOrchestrator
|
||||||
|
from autogen_core import AgentId, CancellationToken
|
||||||
from autogen_ext.models import ReplayChatCompletionClient
|
from autogen_ext.models import ReplayChatCompletionClient
|
||||||
from utils import FileLogHandler
|
from utils import FileLogHandler
|
||||||
|
|
||||||
@ -121,6 +122,27 @@ async def test_magentic_one_group_chat_basic() -> None:
|
|||||||
assert result.messages[4].content == "print('Hello, world!')"
|
assert result.messages[4].content == "print('Hello, world!')"
|
||||||
assert result.stop_reason is not None and result.stop_reason == "Because"
|
assert result.stop_reason is not None and result.stop_reason == "Because"
|
||||||
|
|
||||||
|
# Test save and load.
|
||||||
|
state = await team.save_state()
|
||||||
|
team2 = MagenticOneGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], model_client=model_client)
|
||||||
|
await team2.load_state(state)
|
||||||
|
state2 = await team2.save_state()
|
||||||
|
assert state == state2
|
||||||
|
manager_1 = await team._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
||||||
|
AgentId("group_chat_manager", team._team_id), # pyright: ignore
|
||||||
|
MagenticOneOrchestrator, # pyright: ignore
|
||||||
|
) # pyright: ignore
|
||||||
|
manager_2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
||||||
|
AgentId("group_chat_manager", team2._team_id), # pyright: ignore
|
||||||
|
MagenticOneOrchestrator, # pyright: ignore
|
||||||
|
) # pyright: ignore
|
||||||
|
assert manager_1._message_thread == manager_2._message_thread # pyright: ignore
|
||||||
|
assert manager_1._task == manager_2._task # pyright: ignore
|
||||||
|
assert manager_1._facts == manager_2._facts # pyright: ignore
|
||||||
|
assert manager_1._plan == manager_2._plan # pyright: ignore
|
||||||
|
assert manager_1._n_rounds == manager_2._n_rounds # pyright: ignore
|
||||||
|
assert manager_1._n_stalls == manager_2._n_stalls # pyright: ignore
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_magentic_one_group_chat_with_stalls() -> None:
|
async def test_magentic_one_group_chat_with_stalls() -> None:
|
||||||
|
|||||||
@ -48,6 +48,12 @@ A dynamic team that uses handoffs to pass tasks between agents.
|
|||||||
How to build custom agents.
|
How to build custom agents.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
|
:::{grid-item-card} {fas}`users;pst-color-primary` State Management
|
||||||
|
:link: ./state.html
|
||||||
|
|
||||||
|
How to manage state in agents and teams.
|
||||||
|
:::
|
||||||
|
|
||||||
::::
|
::::
|
||||||
|
|
||||||
```{toctree}
|
```{toctree}
|
||||||
@ -61,4 +67,5 @@ selector-group-chat
|
|||||||
swarm
|
swarm
|
||||||
termination
|
termination
|
||||||
custom-agents
|
custom-agents
|
||||||
|
state
|
||||||
```
|
```
|
||||||
|
|||||||
@ -0,0 +1,299 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Managing State \n",
|
||||||
|
"\n",
|
||||||
|
"So far, we have discussed how to build components in a multi-agent application - agents, teams, termination conditions. In many cases, it is useful to save the state of these components to disk and load them back later. This is particularly useful in a web application where stateless endpoints respond to requests and need to load the state of the application from persistent storage.\n",
|
||||||
|
"\n",
|
||||||
|
"In this notebook, we will discuss how to save and load the state of agents, teams, and termination conditions. \n",
|
||||||
|
" \n",
|
||||||
|
"\n",
|
||||||
|
"## Saving and Loading Agents\n",
|
||||||
|
"\n",
|
||||||
|
"We can get the state of an agent by calling {py:meth}`~autogen_agentchat.agents.AssistantAgent.save_state` method on \n",
|
||||||
|
"an {py:class}`~autogen_agentchat.agents.AssistantAgent`. "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"In Tanganyika's depths so wide and deep, \n",
|
||||||
|
"Ancient secrets in still waters sleep, \n",
|
||||||
|
"Ripples tell tales that time longs to keep. \n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from autogen_agentchat.agents import AssistantAgent\n",
|
||||||
|
"from autogen_agentchat.conditions import MaxMessageTermination\n",
|
||||||
|
"from autogen_agentchat.messages import TextMessage\n",
|
||||||
|
"from autogen_agentchat.teams import RoundRobinGroupChat\n",
|
||||||
|
"from autogen_agentchat.ui import Console\n",
|
||||||
|
"from autogen_core import CancellationToken\n",
|
||||||
|
"from autogen_ext.models import OpenAIChatCompletionClient\n",
|
||||||
|
"\n",
|
||||||
|
"assistant_agent = AssistantAgent(\n",
|
||||||
|
" name=\"assistant_agent\",\n",
|
||||||
|
" system_message=\"You are a helpful assistant\",\n",
|
||||||
|
" model_client=OpenAIChatCompletionClient(\n",
|
||||||
|
" model=\"gpt-4o-2024-08-06\",\n",
|
||||||
|
" # api_key=\"YOUR_API_KEY\",\n",
|
||||||
|
" ),\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"# Use asyncio.run(...) when running in a script.\n",
|
||||||
|
"response = await assistant_agent.on_messages(\n",
|
||||||
|
" [TextMessage(content=\"Write a 3 line poem on lake tangayika\", source=\"user\")], CancellationToken()\n",
|
||||||
|
")\n",
|
||||||
|
"print(response.chat_message.content)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"{'type': 'AssistantAgentState', 'version': '1.0.0', 'llm_messages': [{'content': 'Write a 3 line poem on lake tangayika', 'source': 'user', 'type': 'UserMessage'}, {'content': \"In Tanganyika's depths so wide and deep, \\nAncient secrets in still waters sleep, \\nRipples tell tales that time longs to keep. \", 'source': 'assistant_agent', 'type': 'AssistantMessage'}]}\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"agent_state = await assistant_agent.save_state()\n",
|
||||||
|
"print(agent_state)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"The last line of the poem I wrote was: \n",
|
||||||
|
"\"Ripples tell tales that time longs to keep.\"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"new_assistant_agent = AssistantAgent(\n",
|
||||||
|
" name=\"assistant_agent\",\n",
|
||||||
|
" system_message=\"You are a helpful assistant\",\n",
|
||||||
|
" model_client=OpenAIChatCompletionClient(\n",
|
||||||
|
" model=\"gpt-4o-2024-08-06\",\n",
|
||||||
|
" ),\n",
|
||||||
|
")\n",
|
||||||
|
"await new_assistant_agent.load_state(agent_state)\n",
|
||||||
|
"\n",
|
||||||
|
"# Use asyncio.run(...) when running in a script.\n",
|
||||||
|
"response = await new_assistant_agent.on_messages(\n",
|
||||||
|
" [TextMessage(content=\"What was the last line of the previous poem you wrote\", source=\"user\")], CancellationToken()\n",
|
||||||
|
")\n",
|
||||||
|
"print(response.chat_message.content)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"```{note}\n",
|
||||||
|
"For {py:class}`~autogen_agentchat.agents.AssistantAgent`, its state consists of the model_context.\n",
|
||||||
|
"If your write your own custom agent, consider overriding the {py:meth}`~autogen_agentchat.agents.BaseChatAgent.save_state` and {py:meth}`~autogen_agentchat.agents.BaseChatAgent.load_state` methods to customize the behavior. The default implementations save and load an empty state.\n",
|
||||||
|
"```"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Saving and Loading Teams \n",
|
||||||
|
"\n",
|
||||||
|
"We can get the state of a team by calling `save_state` method on the team and load it back by calling `load_state` method on the team. \n",
|
||||||
|
"\n",
|
||||||
|
"When we call `save_state` on a team, it saves the state of all the agents in the team.\n",
|
||||||
|
"\n",
|
||||||
|
"We will begin by creating a simple {py:class}`~autogen_agentchat.teams.RoundRobinGroupChat` team with a single agent and ask it to write a poem. "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"---------- user ----------\n",
|
||||||
|
"Write a beautiful poem 3-line about lake tangayika\n",
|
||||||
|
"---------- assistant_agent ----------\n",
|
||||||
|
"In Tanganyika's depths, where light gently weaves, \n",
|
||||||
|
"Silver reflections dance on ancient water's face, \n",
|
||||||
|
"Whispered stories of time in the rippling leaves. \n",
|
||||||
|
"[Prompt tokens: 29, Completion tokens: 36]\n",
|
||||||
|
"---------- Summary ----------\n",
|
||||||
|
"Number of messages: 2\n",
|
||||||
|
"Finish reason: Maximum number of messages 2 reached, current message count: 2\n",
|
||||||
|
"Total prompt tokens: 29\n",
|
||||||
|
"Total completion tokens: 36\n",
|
||||||
|
"Duration: 1.16 seconds\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Define a team.\n",
|
||||||
|
"assistant_agent = AssistantAgent(\n",
|
||||||
|
" name=\"assistant_agent\",\n",
|
||||||
|
" system_message=\"You are a helpful assistant\",\n",
|
||||||
|
" model_client=OpenAIChatCompletionClient(\n",
|
||||||
|
" model=\"gpt-4o-2024-08-06\",\n",
|
||||||
|
" ),\n",
|
||||||
|
")\n",
|
||||||
|
"agent_team = RoundRobinGroupChat([assistant_agent], termination_condition=MaxMessageTermination(max_messages=2))\n",
|
||||||
|
"\n",
|
||||||
|
"# Run the team and stream messages to the console.\n",
|
||||||
|
"stream = agent_team.run_stream(task=\"Write a beautiful poem 3-line about lake tangayika\")\n",
|
||||||
|
"\n",
|
||||||
|
"# Use asyncio.run(...) when running in a script.\n",
|
||||||
|
"await Console(stream)\n",
|
||||||
|
"\n",
|
||||||
|
"# Save the state of the agent team.\n",
|
||||||
|
"team_state = await agent_team.save_state()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"If we reset the team (simulating instantiation of the team), and ask the question `What was the last line of the poem you wrote?`, we see that the team is unable to accomplish this as there is no reference to the previous run."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"---------- user ----------\n",
|
||||||
|
"What was the last line of the poem you wrote?\n",
|
||||||
|
"---------- assistant_agent ----------\n",
|
||||||
|
"I don't write poems on my own, but I can help create one with you or try to recall a specific poem if you have one in mind. Let me know what you'd like to do!\n",
|
||||||
|
"[Prompt tokens: 28, Completion tokens: 39]\n",
|
||||||
|
"---------- Summary ----------\n",
|
||||||
|
"Number of messages: 2\n",
|
||||||
|
"Finish reason: Maximum number of messages 2 reached, current message count: 2\n",
|
||||||
|
"Total prompt tokens: 28\n",
|
||||||
|
"Total completion tokens: 39\n",
|
||||||
|
"Duration: 0.95 seconds\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"TaskResult(messages=[TextMessage(source='user', models_usage=None, type='TextMessage', content='What was the last line of the poem you wrote?'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=28, completion_tokens=39), type='TextMessage', content=\"I don't write poems on my own, but I can help create one with you or try to recall a specific poem if you have one in mind. Let me know what you'd like to do!\")], stop_reason='Maximum number of messages 2 reached, current message count: 2')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 5,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"await agent_team.reset()\n",
|
||||||
|
"stream = agent_team.run_stream(task=\"What was the last line of the poem you wrote?\")\n",
|
||||||
|
"await Console(stream)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Next, we load the state of the team and ask the same question. We see that the team is able to accurately return the last line of the poem it wrote.\n",
|
||||||
|
"\n",
|
||||||
|
"Note: You can serialize the state of the team to a file and load it back later."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"{'type': 'TeamState', 'version': '1.0.0', 'agent_states': {'group_chat_manager/c80054be-efb2-4bc7-ba0d-900962092c44': {'type': 'RoundRobinManagerState', 'version': '1.0.0', 'message_thread': [{'source': 'user', 'models_usage': None, 'type': 'TextMessage', 'content': 'Write a beautiful poem 3-line about lake tangayika'}, {'source': 'assistant_agent', 'models_usage': {'prompt_tokens': 29, 'completion_tokens': 36}, 'type': 'TextMessage', 'content': \"In Tanganyika's depths, where light gently weaves, \\nSilver reflections dance on ancient water's face, \\nWhispered stories of time in the rippling leaves. \"}], 'current_turn': 0, 'next_speaker_index': 0}, 'collect_output_messages/c80054be-efb2-4bc7-ba0d-900962092c44': {}, 'assistant_agent/c80054be-efb2-4bc7-ba0d-900962092c44': {'type': 'ChatAgentContainerState', 'version': '1.0.0', 'agent_state': {'type': 'AssistantAgentState', 'version': '1.0.0', 'llm_messages': [{'content': 'Write a beautiful poem 3-line about lake tangayika', 'source': 'user', 'type': 'UserMessage'}, {'content': \"In Tanganyika's depths, where light gently weaves, \\nSilver reflections dance on ancient water's face, \\nWhispered stories of time in the rippling leaves. \", 'source': 'assistant_agent', 'type': 'AssistantMessage'}]}, 'message_buffer': []}}, 'team_id': 'c80054be-efb2-4bc7-ba0d-900962092c44'}\n",
|
||||||
|
"---------- user ----------\n",
|
||||||
|
"What was the last line of the poem you wrote?\n",
|
||||||
|
"---------- assistant_agent ----------\n",
|
||||||
|
"The last line of the poem I wrote was: \n",
|
||||||
|
"\"Whispered stories of time in the rippling leaves.\"\n",
|
||||||
|
"[Prompt tokens: 88, Completion tokens: 24]\n",
|
||||||
|
"---------- Summary ----------\n",
|
||||||
|
"Number of messages: 2\n",
|
||||||
|
"Finish reason: Maximum number of messages 2 reached, current message count: 2\n",
|
||||||
|
"Total prompt tokens: 88\n",
|
||||||
|
"Total completion tokens: 24\n",
|
||||||
|
"Duration: 0.79 seconds\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"TaskResult(messages=[TextMessage(source='user', models_usage=None, type='TextMessage', content='What was the last line of the poem you wrote?'), TextMessage(source='assistant_agent', models_usage=RequestUsage(prompt_tokens=88, completion_tokens=24), type='TextMessage', content='The last line of the poem I wrote was: \\n\"Whispered stories of time in the rippling leaves.\"')], stop_reason='Maximum number of messages 2 reached, current message count: 2')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(team_state)\n",
|
||||||
|
"\n",
|
||||||
|
"# Load team state.\n",
|
||||||
|
"await agent_team.load_state(team_state)\n",
|
||||||
|
"stream = agent_team.run_stream(task=\"What was the last line of the poem you wrote?\")\n",
|
||||||
|
"await Console(stream)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": ".venv",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.5"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
||||||
@ -213,7 +213,7 @@
|
|||||||
" \"tool_enabled_agent\",\n",
|
" \"tool_enabled_agent\",\n",
|
||||||
" lambda: ToolUseAgent(\n",
|
" lambda: ToolUseAgent(\n",
|
||||||
" description=\"Tool Use Agent\",\n",
|
" description=\"Tool Use Agent\",\n",
|
||||||
" system_messages=[SystemMessage(\"You are a helpful AI Assistant. Use your tools to solve problems.\")],\n",
|
" system_messages=[SystemMessage(content=\"You are a helpful AI Assistant. Use your tools to solve problems.\")],\n",
|
||||||
" model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n",
|
" model_client=OpenAIChatCompletionClient(model=\"gpt-4o-mini\"),\n",
|
||||||
" tool_schema=[python_tool.schema],\n",
|
" tool_schema=[python_tool.schema],\n",
|
||||||
" tool_agent_type=tool_agent_type,\n",
|
" tool_agent_type=tool_agent_type,\n",
|
||||||
|
|||||||
@ -174,7 +174,7 @@
|
|||||||
" factory=lambda: TaxSpecialist(\n",
|
" factory=lambda: TaxSpecialist(\n",
|
||||||
" description=\"A tax specialist 1\",\n",
|
" description=\"A tax specialist 1\",\n",
|
||||||
" specialty=TaxSpecialty.PLANNING,\n",
|
" specialty=TaxSpecialty.PLANNING,\n",
|
||||||
" system_messages=[SystemMessage(\"You are a tax specialist.\")],\n",
|
" system_messages=[SystemMessage(content=\"You are a tax specialist.\")],\n",
|
||||||
" ),\n",
|
" ),\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -184,7 +184,7 @@
|
|||||||
" factory=lambda: TaxSpecialist(\n",
|
" factory=lambda: TaxSpecialist(\n",
|
||||||
" description=\"A tax specialist 2\",\n",
|
" description=\"A tax specialist 2\",\n",
|
||||||
" specialty=TaxSpecialty.DISPUTE_RESOLUTION,\n",
|
" specialty=TaxSpecialty.DISPUTE_RESOLUTION,\n",
|
||||||
" system_messages=[SystemMessage(\"You are a tax specialist.\")],\n",
|
" system_messages=[SystemMessage(content=\"You are a tax specialist.\")],\n",
|
||||||
" ),\n",
|
" ),\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
"\n",
|
"\n",
|
||||||
@ -303,7 +303,7 @@
|
|||||||
" factory=lambda specialty=specialty: TaxSpecialist( # type: ignore\n",
|
" factory=lambda specialty=specialty: TaxSpecialist( # type: ignore\n",
|
||||||
" description=f\"A tax specialist in {specialty.value}.\",\n",
|
" description=f\"A tax specialist in {specialty.value}.\",\n",
|
||||||
" specialty=specialty,\n",
|
" specialty=specialty,\n",
|
||||||
" system_messages=[SystemMessage(f\"You are a tax specialist in {specialty.value}.\")],\n",
|
" system_messages=[SystemMessage(content=f\"You are a tax specialist in {specialty.value}.\")],\n",
|
||||||
" ),\n",
|
" ),\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" specialist_subscription = DefaultSubscription(agent_type=specialist_agent_type)\n",
|
" specialist_subscription = DefaultSubscription(agent_type=specialist_agent_type)\n",
|
||||||
@ -414,7 +414,7 @@
|
|||||||
" factory=lambda specialty=specialty: TaxSpecialist( # type: ignore\n",
|
" factory=lambda specialty=specialty: TaxSpecialist( # type: ignore\n",
|
||||||
" description=f\"A tax specialist in {specialty.value}.\",\n",
|
" description=f\"A tax specialist in {specialty.value}.\",\n",
|
||||||
" specialty=specialty,\n",
|
" specialty=specialty,\n",
|
||||||
" system_messages=[SystemMessage(f\"You are a tax specialist in {specialty.value}.\")],\n",
|
" system_messages=[SystemMessage(content=f\"You are a tax specialist in {specialty.value}.\")],\n",
|
||||||
" ),\n",
|
" ),\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" specialist_subscription = TypeSubscription(topic_type=specialty.value, agent_type=specialist_agent_type)\n",
|
" specialist_subscription = TypeSubscription(topic_type=specialty.value, agent_type=specialist_agent_type)\n",
|
||||||
@ -545,7 +545,7 @@
|
|||||||
" factory=lambda specialty=specialty: TaxSpecialist( # type: ignore\n",
|
" factory=lambda specialty=specialty: TaxSpecialist( # type: ignore\n",
|
||||||
" description=f\"A tax specialist in {specialty.value}.\",\n",
|
" description=f\"A tax specialist in {specialty.value}.\",\n",
|
||||||
" specialty=specialty,\n",
|
" specialty=specialty,\n",
|
||||||
" system_messages=[SystemMessage(f\"You are a tax specialist in {specialty.value}.\")],\n",
|
" system_messages=[SystemMessage(content=f\"You are a tax specialist in {specialty.value}.\")],\n",
|
||||||
" ),\n",
|
" ),\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" for tenant in tenants:\n",
|
" for tenant in tenants:\n",
|
||||||
|
|||||||
@ -158,7 +158,7 @@
|
|||||||
" super().__init__(description=description)\n",
|
" super().__init__(description=description)\n",
|
||||||
" self._group_chat_topic_type = group_chat_topic_type\n",
|
" self._group_chat_topic_type = group_chat_topic_type\n",
|
||||||
" self._model_client = model_client\n",
|
" self._model_client = model_client\n",
|
||||||
" self._system_message = SystemMessage(system_message)\n",
|
" self._system_message = SystemMessage(content=system_message)\n",
|
||||||
" self._chat_history: List[LLMMessage] = []\n",
|
" self._chat_history: List[LLMMessage] = []\n",
|
||||||
"\n",
|
"\n",
|
||||||
" @message_handler\n",
|
" @message_handler\n",
|
||||||
@ -427,7 +427,7 @@
|
|||||||
"Read the above conversation. Then select the next role from {participants} to play. Only return the role.\n",
|
"Read the above conversation. Then select the next role from {participants} to play. Only return the role.\n",
|
||||||
"\"\"\"\n",
|
"\"\"\"\n",
|
||||||
" system_message = SystemMessage(\n",
|
" system_message = SystemMessage(\n",
|
||||||
" selector_prompt.format(\n",
|
" content=selector_prompt.format(\n",
|
||||||
" roles=roles,\n",
|
" roles=roles,\n",
|
||||||
" history=history,\n",
|
" history=history,\n",
|
||||||
" participants=str(\n",
|
" participants=str(\n",
|
||||||
|
|||||||
@ -112,7 +112,7 @@
|
|||||||
" system_prompt = \"You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.\\n\\nResponses from models:\"\n",
|
" system_prompt = \"You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.\\n\\nResponses from models:\"\n",
|
||||||
" system_prompt += \"\\n\" + \"\\n\\n\".join([f\"{i+1}. {r}\" for i, r in enumerate(message.previous_results)])\n",
|
" system_prompt += \"\\n\" + \"\\n\\n\".join([f\"{i+1}. {r}\" for i, r in enumerate(message.previous_results)])\n",
|
||||||
" model_result = await self._model_client.create(\n",
|
" model_result = await self._model_client.create(\n",
|
||||||
" [SystemMessage(system_prompt), UserMessage(content=message.task, source=\"user\")]\n",
|
" [SystemMessage(content=system_prompt), UserMessage(content=message.task, source=\"user\")]\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" else:\n",
|
" else:\n",
|
||||||
" # If no previous results are provided, we can simply pass the user query to the model.\n",
|
" # If no previous results are provided, we can simply pass the user query to the model.\n",
|
||||||
@ -174,7 +174,7 @@
|
|||||||
" system_prompt = \"You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.\\n\\nResponses from models:\"\n",
|
" system_prompt = \"You have been provided with a set of responses from various open-source models to the latest user query. Your task is to synthesize these responses into a single, high-quality response. It is crucial to critically evaluate the information provided in these responses, recognizing that some of it may be biased or incorrect. Your response should not simply replicate the given answers but should offer a refined, accurate, and comprehensive reply to the instruction. Ensure your response is well-structured, coherent, and adheres to the highest standards of accuracy and reliability.\\n\\nResponses from models:\"\n",
|
||||||
" system_prompt += \"\\n\" + \"\\n\\n\".join([f\"{i+1}. {r}\" for i, r in enumerate(worker_task.previous_results)])\n",
|
" system_prompt += \"\\n\" + \"\\n\\n\".join([f\"{i+1}. {r}\" for i, r in enumerate(worker_task.previous_results)])\n",
|
||||||
" model_result = await self._model_client.create(\n",
|
" model_result = await self._model_client.create(\n",
|
||||||
" [SystemMessage(system_prompt), UserMessage(content=message.task, source=\"user\")]\n",
|
" [SystemMessage(content=system_prompt), UserMessage(content=message.task, source=\"user\")]\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" assert isinstance(model_result.content, str)\n",
|
" assert isinstance(model_result.content, str)\n",
|
||||||
" return FinalResult(result=model_result.content)"
|
" return FinalResult(result=model_result.content)"
|
||||||
|
|||||||
@ -146,7 +146,7 @@
|
|||||||
" self._buffer: Dict[int, List[IntermediateSolverResponse]] = {}\n",
|
" self._buffer: Dict[int, List[IntermediateSolverResponse]] = {}\n",
|
||||||
" self._system_messages = [\n",
|
" self._system_messages = [\n",
|
||||||
" SystemMessage(\n",
|
" SystemMessage(\n",
|
||||||
" (\n",
|
" content=(\n",
|
||||||
" \"You are a helpful assistant with expertise in mathematics and reasoning. \"\n",
|
" \"You are a helpful assistant with expertise in mathematics and reasoning. \"\n",
|
||||||
" \"Your task is to assist in solving a math reasoning problem by providing \"\n",
|
" \"Your task is to assist in solving a math reasoning problem by providing \"\n",
|
||||||
" \"a clear and detailed solution. Limit your output within 100 words, \"\n",
|
" \"a clear and detailed solution. Limit your output within 100 words, \"\n",
|
||||||
|
|||||||
@ -343,7 +343,7 @@
|
|||||||
"class SimpleAgent(RoutedAgent):\n",
|
"class SimpleAgent(RoutedAgent):\n",
|
||||||
" def __init__(self, model_client: ChatCompletionClient) -> None:\n",
|
" def __init__(self, model_client: ChatCompletionClient) -> None:\n",
|
||||||
" super().__init__(\"A simple agent\")\n",
|
" super().__init__(\"A simple agent\")\n",
|
||||||
" self._system_messages = [SystemMessage(\"You are a helpful AI assistant.\")]\n",
|
" self._system_messages = [SystemMessage(content=\"You are a helpful AI assistant.\")]\n",
|
||||||
" self._model_client = model_client\n",
|
" self._model_client = model_client\n",
|
||||||
"\n",
|
"\n",
|
||||||
" @message_handler\n",
|
" @message_handler\n",
|
||||||
@ -478,7 +478,7 @@
|
|||||||
"class SimpleAgentWithContext(RoutedAgent):\n",
|
"class SimpleAgentWithContext(RoutedAgent):\n",
|
||||||
" def __init__(self, model_client: ChatCompletionClient) -> None:\n",
|
" def __init__(self, model_client: ChatCompletionClient) -> None:\n",
|
||||||
" super().__init__(\"A simple agent\")\n",
|
" super().__init__(\"A simple agent\")\n",
|
||||||
" self._system_messages = [SystemMessage(\"You are a helpful AI assistant.\")]\n",
|
" self._system_messages = [SystemMessage(content=\"You are a helpful AI assistant.\")]\n",
|
||||||
" self._model_client = model_client\n",
|
" self._model_client = model_client\n",
|
||||||
" self._model_context = BufferedChatCompletionContext(buffer_size=5)\n",
|
" self._model_context = BufferedChatCompletionContext(buffer_size=5)\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
|||||||
@ -176,7 +176,7 @@
|
|||||||
"class ToolUseAgent(RoutedAgent):\n",
|
"class ToolUseAgent(RoutedAgent):\n",
|
||||||
" def __init__(self, model_client: ChatCompletionClient, tool_schema: List[ToolSchema], tool_agent_type: str) -> None:\n",
|
" def __init__(self, model_client: ChatCompletionClient, tool_schema: List[ToolSchema], tool_agent_type: str) -> None:\n",
|
||||||
" super().__init__(\"An agent with tools\")\n",
|
" super().__init__(\"An agent with tools\")\n",
|
||||||
" self._system_messages: List[LLMMessage] = [SystemMessage(\"You are a helpful AI assistant.\")]\n",
|
" self._system_messages: List[LLMMessage] = [SystemMessage(content=\"You are a helpful AI assistant.\")]\n",
|
||||||
" self._model_client = model_client\n",
|
" self._model_client = model_client\n",
|
||||||
" self._tool_schema = tool_schema\n",
|
" self._tool_schema = tool_schema\n",
|
||||||
" self._tool_agent_id = AgentId(tool_agent_type, self.id.key)\n",
|
" self._tool_agent_id = AgentId(tool_agent_type, self.id.key)\n",
|
||||||
|
|||||||
@ -36,7 +36,7 @@ Read the following conversation. Then select the next role from {participants} t
|
|||||||
|
|
||||||
Read the above conversation. Then select the next role from {participants} to play. Only return the role.
|
Read the above conversation. Then select the next role from {participants} to play. Only return the role.
|
||||||
"""
|
"""
|
||||||
select_speaker_messages = [SystemMessage(select_speaker_prompt)]
|
select_speaker_messages = [SystemMessage(content=select_speaker_prompt)]
|
||||||
response = await client.create(messages=select_speaker_messages)
|
response = await client.create(messages=select_speaker_messages)
|
||||||
assert isinstance(response.content, str)
|
assert isinstance(response.content, str)
|
||||||
mentions = await mentioned_agents(response.content, agents)
|
mentions = await mentioned_agents(response.content, agents)
|
||||||
|
|||||||
@ -92,7 +92,7 @@ def convert_messages_to_llm_messages(
|
|||||||
converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
|
converted_message_2 = convert_content_message_to_user_message(message, handle_unrepresentable)
|
||||||
if converted_message_2 is not None:
|
if converted_message_2 is not None:
|
||||||
result.append(converted_message_2)
|
result.append(converted_message_2)
|
||||||
case FunctionExecutionResultMessage(_):
|
case FunctionExecutionResultMessage(content=_):
|
||||||
converted_message_3 = convert_tool_call_response_message(message, handle_unrepresentable)
|
converted_message_3 = convert_tool_call_response_message(message, handle_unrepresentable)
|
||||||
if converted_message_3 is not None:
|
if converted_message_3 is not None:
|
||||||
result.append(converted_message_3)
|
result.append(converted_message_3)
|
||||||
|
|||||||
@ -31,7 +31,7 @@ class BaseGroupChatAgent(RoutedAgent):
|
|||||||
super().__init__(description=description)
|
super().__init__(description=description)
|
||||||
self._group_chat_topic_type = group_chat_topic_type
|
self._group_chat_topic_type = group_chat_topic_type
|
||||||
self._model_client = model_client
|
self._model_client = model_client
|
||||||
self._system_message = SystemMessage(system_message)
|
self._system_message = SystemMessage(content=system_message)
|
||||||
self._chat_history: List[LLMMessage] = []
|
self._chat_history: List[LLMMessage] = []
|
||||||
self._ui_config = ui_config
|
self._ui_config = ui_config
|
||||||
self.console = Console()
|
self.console = Console()
|
||||||
@ -126,7 +126,7 @@ Read the following conversation. Then select the next role from {participants} t
|
|||||||
|
|
||||||
Read the above conversation. Then select the next role from {participants} to play. if you think it's enough talking (for example they have talked for {self._max_rounds} rounds), return 'FINISH'.
|
Read the above conversation. Then select the next role from {participants} to play. if you think it's enough talking (for example they have talked for {self._max_rounds} rounds), return 'FINISH'.
|
||||||
"""
|
"""
|
||||||
system_message = SystemMessage(selector_prompt)
|
system_message = SystemMessage(content=selector_prompt)
|
||||||
completion = await self._model_client.create([system_message], cancellation_token=ctx.cancellation_token)
|
completion = await self._model_client.create([system_message], cancellation_token=ctx.cancellation_token)
|
||||||
|
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
|
|||||||
@ -160,12 +160,14 @@ class SchedulingAssistantAgent(RoutedAgent):
|
|||||||
self._name = name
|
self._name = name
|
||||||
self._model_client = model_client
|
self._model_client = model_client
|
||||||
self._system_messages = [
|
self._system_messages = [
|
||||||
SystemMessage(f"""
|
SystemMessage(
|
||||||
|
content=f"""
|
||||||
I am a helpful AI assistant that helps schedule meetings.
|
I am a helpful AI assistant that helps schedule meetings.
|
||||||
If there are missing parameters, I will ask for them.
|
If there are missing parameters, I will ask for them.
|
||||||
|
|
||||||
Today's date is {datetime.datetime.now().strftime("%Y-%m-%d")}
|
Today's date is {datetime.datetime.now().strftime("%Y-%m-%d")}
|
||||||
""")
|
"""
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@message_handler
|
@message_handler
|
||||||
|
|||||||
@ -116,10 +116,12 @@ class ClosureAgent(BaseAgent, ClosureContext):
|
|||||||
return await self._closure(self, message, ctx)
|
return await self._closure(self, message, ctx)
|
||||||
|
|
||||||
async def save_state(self) -> Mapping[str, Any]:
|
async def save_state(self) -> Mapping[str, Any]:
|
||||||
raise ValueError("save_state not implemented for ClosureAgent")
|
"""Closure agents do not have state. So this method always returns an empty dictionary."""
|
||||||
|
return {}
|
||||||
|
|
||||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||||
raise ValueError("load_state not implemented for ClosureAgent")
|
"""Closure agents do not have state. So this method does nothing."""
|
||||||
|
pass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def register_closure(
|
async def register_closure(
|
||||||
|
|||||||
@ -1,42 +1,49 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Literal, Optional, Union
|
from typing import List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from ... import FunctionCall, Image
|
from ... import FunctionCall, Image
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class SystemMessage(BaseModel):
|
||||||
class SystemMessage:
|
|
||||||
content: str
|
content: str
|
||||||
|
type: Literal["SystemMessage"] = "SystemMessage"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class UserMessage(BaseModel):
|
||||||
class UserMessage:
|
|
||||||
content: Union[str, List[Union[str, Image]]]
|
content: Union[str, List[Union[str, Image]]]
|
||||||
|
|
||||||
# Name of the agent that sent this message
|
# Name of the agent that sent this message
|
||||||
source: str
|
source: str
|
||||||
|
|
||||||
|
type: Literal["UserMessage"] = "UserMessage"
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AssistantMessage:
|
class AssistantMessage(BaseModel):
|
||||||
content: Union[str, List[FunctionCall]]
|
content: Union[str, List[FunctionCall]]
|
||||||
|
|
||||||
# Name of the agent that sent this message
|
# Name of the agent that sent this message
|
||||||
source: str
|
source: str
|
||||||
|
|
||||||
|
type: Literal["AssistantMessage"] = "AssistantMessage"
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FunctionExecutionResult:
|
class FunctionExecutionResult(BaseModel):
|
||||||
content: str
|
content: str
|
||||||
call_id: str
|
call_id: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class FunctionExecutionResultMessage(BaseModel):
|
||||||
class FunctionExecutionResultMessage:
|
|
||||||
content: List[FunctionExecutionResult]
|
content: List[FunctionExecutionResult]
|
||||||
|
|
||||||
|
type: Literal["FunctionExecutionResultMessage"] = "FunctionExecutionResultMessage"
|
||||||
|
|
||||||
LLMMessage = Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage]
|
|
||||||
|
LLMMessage = Annotated[
|
||||||
|
Union[SystemMessage, UserMessage, AssistantMessage, FunctionExecutionResultMessage], Field(discriminator="type")
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -54,16 +61,14 @@ class TopLogprob:
|
|||||||
bytes: Optional[List[int]] = None
|
bytes: Optional[List[int]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class ChatCompletionTokenLogprob(BaseModel):
|
||||||
class ChatCompletionTokenLogprob:
|
|
||||||
token: str
|
token: str
|
||||||
logprob: float
|
logprob: float
|
||||||
top_logprobs: Optional[List[TopLogprob] | None] = None
|
top_logprobs: Optional[List[TopLogprob] | None] = None
|
||||||
bytes: Optional[List[int]] = None
|
bytes: Optional[List[int]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
class CreateResult(BaseModel):
|
||||||
class CreateResult:
|
|
||||||
finish_reason: FinishReasons
|
finish_reason: FinishReasons
|
||||||
content: Union[str, List[FunctionCall]]
|
content: Union[str, List[FunctionCall]]
|
||||||
usage: RequestUsage
|
usage: RequestUsage
|
||||||
|
|||||||
@ -36,9 +36,11 @@ class FileSurfer(BaseChatAgent):
|
|||||||
DEFAULT_DESCRIPTION = "An agent that can handle local files."
|
DEFAULT_DESCRIPTION = "An agent that can handle local files."
|
||||||
|
|
||||||
DEFAULT_SYSTEM_MESSAGES = [
|
DEFAULT_SYSTEM_MESSAGES = [
|
||||||
SystemMessage("""
|
SystemMessage(
|
||||||
|
content="""
|
||||||
You are a helpful AI Assistant.
|
You are a helpful AI Assistant.
|
||||||
When given a user query, use available functions to help the user with their request."""),
|
When given a user query, use available functions to help the user with their request."""
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -78,7 +80,7 @@ class FileSurfer(BaseChatAgent):
|
|||||||
|
|
||||||
except BaseException:
|
except BaseException:
|
||||||
content = f"File surfing error:\n\n{traceback.format_exc()}"
|
content = f"File surfing error:\n\n{traceback.format_exc()}"
|
||||||
self._chat_history.append(AssistantMessage(content, source=self.name))
|
self._chat_history.append(AssistantMessage(content=content, source=self.name))
|
||||||
return Response(chat_message=TextMessage(content=content, source=self.name))
|
return Response(chat_message=TextMessage(content=content, source=self.name))
|
||||||
|
|
||||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||||
|
|||||||
@ -190,7 +190,7 @@ class MultimodalWebSurfer(BaseChatAgent):
|
|||||||
|
|
||||||
except BaseException:
|
except BaseException:
|
||||||
content = f"Web surfing error:\n\n{traceback.format_exc()}"
|
content = f"Web surfing error:\n\n{traceback.format_exc()}"
|
||||||
self._chat_history.append(AssistantMessage(content, source=self.name))
|
self._chat_history.append(AssistantMessage(content=content, source=self.name))
|
||||||
return Response(chat_message=TextMessage(content=content, source=self.name))
|
return Response(chat_message=TextMessage(content=content, source=self.name))
|
||||||
|
|
||||||
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
||||||
@ -712,7 +712,7 @@ class MultimodalWebSurfer(BaseChatAgent):
|
|||||||
for line in page_markdown.splitlines():
|
for line in page_markdown.splitlines():
|
||||||
message = UserMessage(
|
message = UserMessage(
|
||||||
# content=[
|
# content=[
|
||||||
prompt + buffer + line,
|
content=prompt + buffer + line,
|
||||||
# ag_image,
|
# ag_image,
|
||||||
# ],
|
# ],
|
||||||
source=self.name,
|
source=self.name,
|
||||||
|
|||||||
@ -33,7 +33,7 @@ class LLMAgent(RoutedAgent):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def _fixed_message_history_type(self) -> List[SystemMessage]:
|
def _fixed_message_history_type(self) -> List[SystemMessage]:
|
||||||
return [SystemMessage(msg.content) for msg in self._chat_history]
|
return [SystemMessage(content=msg.content) for msg in self._chat_history]
|
||||||
|
|
||||||
|
|
||||||
@default_subscription
|
@default_subscription
|
||||||
|
|||||||
@ -21,7 +21,8 @@ class Coder(BaseWorker):
|
|||||||
DEFAULT_DESCRIPTION = "A helpful and general-purpose AI assistant that has strong language skills, Python skills, and Linux command line skills."
|
DEFAULT_DESCRIPTION = "A helpful and general-purpose AI assistant that has strong language skills, Python skills, and Linux command line skills."
|
||||||
|
|
||||||
DEFAULT_SYSTEM_MESSAGES = [
|
DEFAULT_SYSTEM_MESSAGES = [
|
||||||
SystemMessage("""You are a helpful AI assistant.
|
SystemMessage(
|
||||||
|
content="""You are a helpful AI assistant.
|
||||||
Solve tasks using your coding and language skills.
|
Solve tasks using your coding and language skills.
|
||||||
In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.
|
In the following cases, suggest python code (in a python coding block) or shell script (in a sh coding block) for the user to execute.
|
||||||
1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.
|
1. When you need to collect info, use the code to output the info you need, for example, browse or search the web, download/read a file, print the content of a webpage or a file, get the current date/time, check the operating system. After sufficient info is printed and the task is ready to be solved based on your language skill, you can solve the task by yourself.
|
||||||
@ -31,7 +32,8 @@ When using code, you must indicate the script type in the code block. The user c
|
|||||||
If you want the user to save the code in a file before executing it, put # filename: <filename> inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.
|
If you want the user to save the code in a file before executing it, put # filename: <filename> inside the code block as the first line. Don't include multiple code blocks in one response. Do not ask users to copy and paste the result. Instead, use 'print' function for the output when relevant. Check the execution result returned by the user.
|
||||||
If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.
|
If the result indicates there is an error, fix the error and output the code again. Suggest the full code instead of partial code or code changes. If the error can't be fixed or if the task is not solved even after the code is executed successfully, analyze the problem, revisit your assumption, collect additional info you need, and think of a different approach to try.
|
||||||
When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
|
When you find an answer, verify the answer carefully. Include verifiable evidence in your response if possible.
|
||||||
Reply "TERMINATE" in the end when everything is done.""")
|
Reply "TERMINATE" in the end when everything is done."""
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@ -23,9 +23,11 @@ class FileSurfer(BaseWorker):
|
|||||||
DEFAULT_DESCRIPTION = "An agent that can handle local files."
|
DEFAULT_DESCRIPTION = "An agent that can handle local files."
|
||||||
|
|
||||||
DEFAULT_SYSTEM_MESSAGES = [
|
DEFAULT_SYSTEM_MESSAGES = [
|
||||||
SystemMessage("""
|
SystemMessage(
|
||||||
|
content="""
|
||||||
You are a helpful AI Assistant.
|
You are a helpful AI Assistant.
|
||||||
When given a user query, use available functions to help the user with their request."""),
|
When given a user query, use available functions to help the user with their request."""
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@ -817,7 +817,7 @@ When deciding between tools, consider if the request can be best addressed by:
|
|||||||
for line in re.split(r"([\r\n]+)", page_markdown):
|
for line in re.split(r"([\r\n]+)", page_markdown):
|
||||||
message = UserMessage(
|
message = UserMessage(
|
||||||
# content=[
|
# content=[
|
||||||
prompt + buffer + line,
|
content=prompt + buffer + line,
|
||||||
# ag_image,
|
# ag_image,
|
||||||
# ],
|
# ],
|
||||||
source=self.metadata["type"],
|
source=self.metadata["type"],
|
||||||
|
|||||||
@ -47,7 +47,7 @@ class LedgerOrchestrator(BaseOrchestrator):
|
|||||||
It uses a ledger (implemented as a JSON generated by the LLM) to keep track of task progress and select the next agent that should speak."""
|
It uses a ledger (implemented as a JSON generated by the LLM) to keep track of task progress and select the next agent that should speak."""
|
||||||
|
|
||||||
DEFAULT_SYSTEM_MESSAGES = [
|
DEFAULT_SYSTEM_MESSAGES = [
|
||||||
SystemMessage(ORCHESTRATOR_SYSTEM_MESSAGE),
|
SystemMessage(content=ORCHESTRATOR_SYSTEM_MESSAGE),
|
||||||
]
|
]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user