mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-20 14:42:33 +00:00
Register returns AgentType (#382)
This commit is contained in:
parent
e1a823fb6d
commit
29088d67a4
@ -12,7 +12,7 @@ from dataclasses import dataclass
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
from typing import Any, Awaitable, Callable, DefaultDict, Dict, List, Mapping, ParamSpec, Set, Type, TypeVar, cast
|
||||||
|
|
||||||
from agnext.core import Subscription, TopicId
|
from agnext.core import AgentType, Subscription, TopicId
|
||||||
|
|
||||||
from ..core import (
|
from ..core import (
|
||||||
Agent,
|
Agent,
|
||||||
@ -445,10 +445,11 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
|||||||
self,
|
self,
|
||||||
type: str,
|
type: str,
|
||||||
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
|
||||||
) -> None:
|
) -> AgentType:
|
||||||
if type in self._agent_factories:
|
if type in self._agent_factories:
|
||||||
raise ValueError(f"Agent with type {type} already exists.")
|
raise ValueError(f"Agent with type {type} already exists.")
|
||||||
self._agent_factories[type] = agent_factory
|
self._agent_factories[type] = agent_factory
|
||||||
|
return AgentType(type)
|
||||||
|
|
||||||
async def _invoke_agent_factory(
|
async def _invoke_agent_factory(
|
||||||
self,
|
self,
|
||||||
|
@ -28,7 +28,7 @@ import grpc
|
|||||||
from grpc.aio import StreamStreamCall
|
from grpc.aio import StreamStreamCall
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
from agnext.core import MESSAGE_TYPE_REGISTRY, MessageContext, Subscription, TopicId
|
from agnext.core import MESSAGE_TYPE_REGISTRY, AgentType, MessageContext, Subscription, TopicId
|
||||||
|
|
||||||
from ..core import Agent, AgentId, AgentInstantiationContext, AgentMetadata, AgentRuntime, CancellationToken
|
from ..core import Agent, AgentId, AgentInstantiationContext, AgentMetadata, AgentRuntime, CancellationToken
|
||||||
from .protos import AgentId as AgentIdProto
|
from .protos import AgentId as AgentIdProto
|
||||||
@ -352,7 +352,7 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||||||
self,
|
self,
|
||||||
type: str,
|
type: str,
|
||||||
agent_factory: Callable[[], T | Awaitable[T]],
|
agent_factory: Callable[[], T | Awaitable[T]],
|
||||||
) -> None:
|
) -> AgentType:
|
||||||
if type in self._agent_factories:
|
if type in self._agent_factories:
|
||||||
raise ValueError(f"Agent with type {type} already exists.")
|
raise ValueError(f"Agent with type {type} already exists.")
|
||||||
self._agent_factories[type] = agent_factory
|
self._agent_factories[type] = agent_factory
|
||||||
@ -361,6 +361,7 @@ class WorkerAgentRuntime(AgentRuntime):
|
|||||||
message = Message(registerAgentType=RegisterAgentType(type=type))
|
message = Message(registerAgentType=RegisterAgentType(type=type))
|
||||||
await self._host_connection.send(message)
|
await self._host_connection.send(message)
|
||||||
logger.info("Sent registerAgentType message for %s", type)
|
logger.info("Sent registerAgentType message for %s", type)
|
||||||
|
return AgentType(type)
|
||||||
|
|
||||||
async def _invoke_agent_factory(
|
async def _invoke_agent_factory(
|
||||||
self,
|
self,
|
||||||
|
@ -9,6 +9,7 @@ from ._agent_metadata import AgentMetadata
|
|||||||
from ._agent_props import AgentChildren
|
from ._agent_props import AgentChildren
|
||||||
from ._agent_proxy import AgentProxy
|
from ._agent_proxy import AgentProxy
|
||||||
from ._agent_runtime import AgentRuntime
|
from ._agent_runtime import AgentRuntime
|
||||||
|
from ._agent_type import AgentType
|
||||||
from ._base_agent import BaseAgent
|
from ._base_agent import BaseAgent
|
||||||
from ._cancellation_token import CancellationToken
|
from ._cancellation_token import CancellationToken
|
||||||
from ._message_context import MessageContext
|
from ._message_context import MessageContext
|
||||||
@ -33,4 +34,5 @@ __all__ = [
|
|||||||
"Subscription",
|
"Subscription",
|
||||||
"MessageContext",
|
"MessageContext",
|
||||||
"Serialization",
|
"Serialization",
|
||||||
|
"AgentType",
|
||||||
]
|
]
|
||||||
|
@ -1,8 +1,13 @@
|
|||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from ._agent_type import AgentType
|
||||||
|
|
||||||
|
|
||||||
class AgentId:
|
class AgentId:
|
||||||
def __init__(self, type: str, key: str) -> None:
|
def __init__(self, type: str | AgentType, key: str) -> None:
|
||||||
|
if isinstance(type, AgentType):
|
||||||
|
type = type.type
|
||||||
|
|
||||||
if type.isidentifier() is False:
|
if type.isidentifier() is False:
|
||||||
raise ValueError(f"Invalid type: {type}")
|
raise ValueError(f"Invalid type: {type}")
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ from typing import Any, Awaitable, Callable, Mapping, Protocol, Type, TypeVar, r
|
|||||||
from ._agent import Agent
|
from ._agent import Agent
|
||||||
from ._agent_id import AgentId
|
from ._agent_id import AgentId
|
||||||
from ._agent_metadata import AgentMetadata
|
from ._agent_metadata import AgentMetadata
|
||||||
|
from ._agent_type import AgentType
|
||||||
from ._cancellation_token import CancellationToken
|
from ._cancellation_token import CancellationToken
|
||||||
from ._subscription import Subscription
|
from ._subscription import Subscription
|
||||||
from ._topic import TopicId
|
from ._topic import TopicId
|
||||||
@ -70,7 +71,7 @@ class AgentRuntime(Protocol):
|
|||||||
self,
|
self,
|
||||||
type: str,
|
type: str,
|
||||||
agent_factory: Callable[[], T | Awaitable[T]],
|
agent_factory: Callable[[], T | Awaitable[T]],
|
||||||
) -> None:
|
) -> AgentType:
|
||||||
"""Register an agent factory with the runtime associated with a specific type. The type must be unique.
|
"""Register an agent factory with the runtime associated with a specific type. The type must be unique.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
7
python/src/agnext/core/_agent_type.py
Normal file
7
python/src/agnext/core/_agent_type.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=True, frozen=True)
|
||||||
|
class AgentType:
|
||||||
|
type: str
|
||||||
|
"""String representation of this agent type."""
|
Loading…
x
Reference in New Issue
Block a user