From 29088d67a47c433b6985aa7ae6f9adbdef71ce6b Mon Sep 17 00:00:00 2001 From: Jack Gerrits Date: Tue, 20 Aug 2024 17:38:36 -0400 Subject: [PATCH] Register returns AgentType (#382) --- .../agnext/application/_single_threaded_agent_runtime.py | 5 +++-- python/src/agnext/application/_worker_runtime.py | 5 +++-- python/src/agnext/core/__init__.py | 2 ++ python/src/agnext/core/_agent_id.py | 7 ++++++- python/src/agnext/core/_agent_runtime.py | 3 ++- python/src/agnext/core/_agent_type.py | 7 +++++++ 6 files changed, 23 insertions(+), 6 deletions(-) create mode 100644 python/src/agnext/core/_agent_type.py diff --git a/python/src/agnext/application/_single_threaded_agent_runtime.py b/python/src/agnext/application/_single_threaded_agent_runtime.py index 5cfd3e5d4..87eab51fa 100644 --- a/python/src/agnext/application/_single_threaded_agent_runtime.py +++ b/python/src/agnext/application/_single_threaded_agent_runtime.py @@ -12,7 +12,7 @@ from dataclasses import dataclass from enum import Enum from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast -from agnext.core import Subscription, TopicId +from agnext.core import AgentType, Subscription, TopicId from ..core import ( Agent, @@ -445,10 +445,11 @@ class SingleThreadedAgentRuntime(AgentRuntime): self, type: str, agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]], - ) -> None: + ) -> AgentType: if type in self._agent_factories: raise ValueError(f"Agent with type {type} already exists.") self._agent_factories[type] = agent_factory + return AgentType(type) async def _invoke_agent_factory( self, diff --git a/python/src/agnext/application/_worker_runtime.py b/python/src/agnext/application/_worker_runtime.py index 42e58bc63..ed101aa0d 100644 --- a/python/src/agnext/application/_worker_runtime.py +++ b/python/src/agnext/application/_worker_runtime.py @@ -28,7 +28,7 @@ import grpc from grpc.aio import StreamStreamCall from typing_extensions import Self -from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext, Subscription, TopicId +from agnext.core import MESSAGE_TYPE_REGISTRY, AgentType, MessageContext, Subscription, TopicId from ..core import Agent, AgentId, AgentInstantiationContext, AgentMetadata, AgentRuntime, CancellationToken from .protos import AgentId as AgentIdProto @@ -352,7 +352,7 @@ class WorkerAgentRuntime(AgentRuntime): self, type: str, agent_factory: Callable[[], T | Awaitable[T]], - ) -> None: + ) -> AgentType: if type in self._agent_factories: raise ValueError(f"Agent with type {type} already exists.") self._agent_factories[type] = agent_factory @@ -361,6 +361,7 @@ class WorkerAgentRuntime(AgentRuntime): message = Message(registerAgentType=RegisterAgentType(type=type)) await self._host_connection.send(message) logger.info("Sent registerAgentType message for %s", type) + return AgentType(type) async def _invoke_agent_factory( self, diff --git a/python/src/agnext/core/__init__.py b/python/src/agnext/core/__init__.py index 851a4f347..1c78ead2f 100644 --- a/python/src/agnext/core/__init__.py +++ b/python/src/agnext/core/__init__.py @@ -9,6 +9,7 @@ from ._agent_metadata import AgentMetadata from ._agent_props import AgentChildren from ._agent_proxy import AgentProxy from ._agent_runtime import AgentRuntime +from ._agent_type import AgentType from ._base_agent import BaseAgent from ._cancellation_token import CancellationToken from ._message_context import MessageContext @@ -33,4 +34,5 @@ __all__ = [ "Subscription", "MessageContext", "Serialization", + "AgentType", ] diff --git a/python/src/agnext/core/_agent_id.py b/python/src/agnext/core/_agent_id.py index a1d9ae6d5..1459f250f 100644 --- a/python/src/agnext/core/_agent_id.py +++ b/python/src/agnext/core/_agent_id.py @@ -1,8 +1,13 @@ from typing_extensions import Self +from ._agent_type import AgentType + class AgentId: - def __init__(self, type: str, key: str) -> None: + def __init__(self, type: str | AgentType, key: str) -> None: + if isinstance(type, AgentType): + type = type.type + if type.isidentifier() is False: raise ValueError(f"Invalid type: {type}") diff --git a/python/src/agnext/core/_agent_runtime.py b/python/src/agnext/core/_agent_runtime.py index 0f05f2f2d..33368c5dc 100644 --- a/python/src/agnext/core/_agent_runtime.py +++ b/python/src/agnext/core/_agent_runtime.py @@ -5,6 +5,7 @@ from typing import Any, Awaitable, Callable, Mapping, Protocol, Type, TypeVar, r from ._agent import Agent from ._agent_id import AgentId from ._agent_metadata import AgentMetadata +from ._agent_type import AgentType from ._cancellation_token import CancellationToken from ._subscription import Subscription from ._topic import TopicId @@ -70,7 +71,7 @@ class AgentRuntime(Protocol): self, type: str, agent_factory: Callable[[], T | Awaitable[T]], - ) -> None: + ) -> AgentType: """Register an agent factory with the runtime associated with a specific type. The type must be unique. Args: diff --git a/python/src/agnext/core/_agent_type.py b/python/src/agnext/core/_agent_type.py new file mode 100644 index 000000000..009f8c9c4 --- /dev/null +++ b/python/src/agnext/core/_agent_type.py @@ -0,0 +1,7 @@ +from dataclasses import dataclass + + +@dataclass(eq=True, frozen=True) +class AgentType: + type: str + """String representation of this agent type."""