From c29218b329e56de9e3ad12a96cc307b6b710884a Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Mon, 17 Jun 2024 12:43:51 -0400 Subject: [PATCH] Add agent proxy (#84) --- .../_single_threaded_agent_runtime.py | 6 +++ src/agnext/core/__init__.py | 3 +- src/agnext/core/_agent_id.py | 26 ++++++++++ src/agnext/core/_agent_proxy.py | 50 +++++++++++++++++++ src/agnext/core/_agent_runtime.py | 4 ++ 5 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 src/agnext/core/_agent_id.py create mode 100644 src/agnext/core/_agent_proxy.py diff --git a/src/agnext/application/_single_threaded_agent_runtime.py b/src/agnext/application/_single_threaded_agent_runtime.py index 932e35fe7..9dcbc7e7b 100644 --- a/src/agnext/application/_single_threaded_agent_runtime.py +++ b/src/agnext/application/_single_threaded_agent_runtime.py @@ -302,3 +302,9 @@ class SingleThreadedAgentRuntime(AgentRuntime): def agent_metadata(self, agent: Agent) -> AgentMetadata: return agent.metadata + + def agent_save_state(self, agent: Agent) -> Mapping[str, Any]: + return agent.save_state() + + def agent_load_state(self, agent: Agent, state: Mapping[str, Any]) -> None: + agent.load_state(state) diff --git a/src/agnext/core/__init__.py b/src/agnext/core/__init__.py index 605e33bf3..01b3d3456 100644 --- a/src/agnext/core/__init__.py +++ b/src/agnext/core/__init__.py @@ -3,10 +3,11 @@ The :mod:`agnext.core` module provides the foundational generic interfaces upon """ from ._agent import Agent +from ._agent_id import AgentId from ._agent_metadata import AgentMetadata from ._agent_props import AgentChildren from ._agent_runtime import AgentRuntime from ._base_agent import BaseAgent from ._cancellation_token import CancellationToken -__all__ = ["Agent", "AgentMetadata", "AgentRuntime", "BaseAgent", "CancellationToken", "AgentChildren"] +__all__ = ["Agent", "AgentId", "AgentMetadata", "AgentRuntime", "BaseAgent", "CancellationToken", "AgentChildren"] diff --git a/src/agnext/core/_agent_id.py b/src/agnext/core/_agent_id.py new file mode 100644 index 000000000..a329bbcce --- /dev/null +++ b/src/agnext/core/_agent_id.py @@ -0,0 +1,26 @@ +from typing_extensions import Self + + +class AgentId: + def __init__(self, name: str, namespace: str) -> None: + self._name = name + self._namespace = namespace + + def __str__(self) -> str: + return f"{self._namespace}/{self._name}" + + def __hash__(self) -> int: + return hash((self._namespace, self._name)) + + @classmethod + def from_str(cls, agent_id: str) -> Self: + namespace, name = agent_id.split("/") + return cls(name, namespace) + + @property + def namespace(self) -> str: + return self._namespace + + @property + def name(self) -> str: + return self._name diff --git a/src/agnext/core/_agent_proxy.py b/src/agnext/core/_agent_proxy.py new file mode 100644 index 000000000..78cefdce3 --- /dev/null +++ b/src/agnext/core/_agent_proxy.py @@ -0,0 +1,50 @@ +from asyncio import Future +from typing import Any, Mapping + +from ._agent import Agent +from ._agent_id import AgentId +from ._agent_metadata import AgentMetadata +from ._agent_runtime import AgentRuntime +from ._cancellation_token import CancellationToken + + +class AgentProxy: + def __init__(self, agent: Agent, runtime: AgentRuntime): + self._agent = agent + self._runtime = runtime + + @property + def id(self) -> AgentId: + """Target agent for this proxy""" + raise NotImplementedError + + @property + def metadata(self) -> AgentMetadata: + """Metadata of the agent.""" + return self._runtime.agent_metadata(self._agent) + + def send_message( + self, + message: Any, + *, + sender: Agent, + cancellation_token: CancellationToken | None = None, + ) -> Future[Any]: + return self._runtime.send_message( + message, + recipient=self._agent, + sender=sender, + cancellation_token=cancellation_token, + ) + + def save_state(self) -> Mapping[str, Any]: + """Save the state of the agent. The result must be JSON serializable.""" + return self._runtime.agent_save_state(self._agent) + + def load_state(self, state: Mapping[str, Any]) -> None: + """Load in the state of the agent obtained from `save_state`. + + Args: + state (Mapping[str, Any]): State of the agent. Must be JSON serializable. + """ + self._runtime.agent_load_state(self._agent, state) diff --git a/src/agnext/core/_agent_runtime.py b/src/agnext/core/_agent_runtime.py index 2c8e6a569..964ed2416 100644 --- a/src/agnext/core/_agent_runtime.py +++ b/src/agnext/core/_agent_runtime.py @@ -44,3 +44,7 @@ class AgentRuntime(Protocol): def load_state(self, state: Mapping[str, Any]) -> None: ... def agent_metadata(self, agent: Agent) -> AgentMetadata: ... + + def agent_save_state(self, agent: Agent) -> Mapping[str, Any]: ... + + def agent_load_state(self, agent: Agent, state: Mapping[str, Any]) -> None: ...