Add ability to register Agent instances (#6131)

<!-- Thank you for your contribution! Please review
https://microsoft.github.io/autogen/docs/Contribute before opening a
pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have the access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed?

Nice to have functionality

## Related issue number

Closes #6060 

## Checks

- [x] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [x] I've made sure all auto checks have passed.

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
peterychang 2025-05-12 11:34:48 -04:00 committed by GitHub
parent c26d894c34
commit 9118f9b998
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 401 additions and 26 deletions

View File

@ -1,9 +1,13 @@
from typing import Any, Mapping, Protocol, runtime_checkable
from typing import TYPE_CHECKING, Any, Mapping, Protocol, runtime_checkable
from ._agent_id import AgentId
from ._agent_metadata import AgentMetadata
from ._message_context import MessageContext
# Forward declaration for type checking only
if TYPE_CHECKING:
from ._agent_runtime import AgentRuntime
@runtime_checkable
class Agent(Protocol):
@ -17,6 +21,15 @@ class Agent(Protocol):
"""ID of the agent."""
...
async def bind_id_and_runtime(self, id: AgentId, runtime: "AgentRuntime") -> None:
"""Function used to bind an Agent instance to an `AgentRuntime`.
Args:
agent_id (AgentId): ID of the agent.
runtime (AgentRuntime): AgentRuntime instance to bind the agent to.
"""
...
async def on_message(self, message: Any, ctx: MessageContext) -> Any:
"""Message handler for the agent. This should only be called by the runtime, not by other agents.

View File

@ -118,3 +118,9 @@ class AgentInstantiationContext:
raise RuntimeError(
"AgentInstantiationContext.agent_id() must be called within an instantiation context such as when the AgentRuntime is instantiating an agent. Mostly likely this was caused by directly instantiating an agent instead of using the AgentRuntime to do so."
) from e
@classmethod
def is_in_factory_call(cls) -> bool:
if cls._AGENT_INSTANTIATION_CONTEXT_VAR.get(None) is None:
return False
return True

View File

@ -130,6 +130,60 @@ class AgentRuntime(Protocol):
"""
...
async def register_agent_instance(
self,
agent_instance: Agent,
agent_id: AgentId,
) -> AgentId:
"""Register an agent instance with the runtime. The type may be reused, but each agent_id must be unique. All agent instances within a type must be of the same object type. This API does not add any subscriptions.
.. note::
This is a low level API and usually the agent class's `register_instance` method should be used instead, as this also handles subscriptions automatically.
Example:
.. code-block:: python
from dataclasses import dataclass
from autogen_core import AgentId, AgentRuntime, MessageContext, RoutedAgent, event
from autogen_core.models import UserMessage
@dataclass
class MyMessage:
content: str
class MyAgent(RoutedAgent):
def __init__(self) -> None:
super().__init__("My core agent")
@event
async def handler(self, message: UserMessage, context: MessageContext) -> None:
print("Event received: ", message.content)
async def main() -> None:
runtime: AgentRuntime = ... # type: ignore
agent = MyAgent()
await runtime.register_agent_instance(
agent_instance=agent, agent_id=AgentId(type="my_agent", key="default")
)
import asyncio
asyncio.run(main())
Args:
agent_instance (Agent): A concrete instance of the agent.
agent_id (AgentId): The agent's identifier. The agent's type is `agent_id.type`.
"""
...
# TODO: uncomment out the following type ignore when this is fixed in mypy: https://github.com/python/mypy/issues/3737
async def try_get_underlying_agent_instance(self, id: AgentId, type: Type[T] = Agent) -> T: # type: ignore[assignment]
"""Try to get the underlying agent instance by name and namespace. This is generally discouraged (hence the long name), but can be useful in some cases.

View File

@ -21,6 +21,7 @@ from ._subscription import Subscription, UnboundSubscription
from ._subscription_context import SubscriptionInstantiationContext
from ._topic import TopicId
from ._type_prefix_subscription import TypePrefixSubscription
from ._type_subscription import TypeSubscription
T = TypeVar("T", bound=Agent)
@ -82,20 +83,25 @@ class BaseAgent(ABC, Agent):
return AgentMetadata(key=self._id.key, type=self._id.type, description=self._description)
def __init__(self, description: str) -> None:
try:
runtime = AgentInstantiationContext.current_runtime()
id = AgentInstantiationContext.current_agent_id()
except LookupError as e:
raise RuntimeError(
"BaseAgent must be instantiated within the context of an AgentRuntime. It cannot be directly instantiated."
) from e
self._runtime: AgentRuntime = runtime
self._id: AgentId = id
if AgentInstantiationContext.is_in_factory_call():
self._runtime: AgentRuntime = AgentInstantiationContext.current_runtime()
self._id = AgentInstantiationContext.current_agent_id()
if not isinstance(description, str):
raise ValueError("Agent description must be a string")
self._description = description
async def bind_id_and_runtime(self, id: AgentId, runtime: AgentRuntime) -> None:
if hasattr(self, "_id"):
if self._id != id:
raise RuntimeError("Agent is already bound to a different ID")
if hasattr(self, "_runtime"):
if self._runtime != runtime:
raise RuntimeError("Agent is already bound to a different runtime")
self._id = id
self._runtime = runtime
@property
def type(self) -> str:
return self.id.type
@ -155,6 +161,56 @@ class BaseAgent(ABC, Agent):
async def close(self) -> None:
pass
async def register_instance(
self,
runtime: AgentRuntime,
agent_id: AgentId,
*,
skip_class_subscriptions: bool = True,
skip_direct_message_subscription: bool = False,
) -> AgentId:
"""
This function is similar to `register` but is used for registering an instance of an agent. A subscription based on the agent ID is created and added to the runtime.
"""
agent_id = await runtime.register_agent_instance(agent_instance=self, agent_id=agent_id)
id_subscription = TypeSubscription(topic_type=agent_id.key, agent_type=agent_id.type)
await runtime.add_subscription(id_subscription)
if not skip_class_subscriptions:
with SubscriptionInstantiationContext.populate_context(AgentType(agent_id.type)):
subscriptions: List[Subscription] = []
for unbound_subscription in self._unbound_subscriptions():
subscriptions_list_result = unbound_subscription()
if inspect.isawaitable(subscriptions_list_result):
subscriptions_list = await subscriptions_list_result
else:
subscriptions_list = subscriptions_list_result
subscriptions.extend(subscriptions_list)
for subscription in subscriptions:
await runtime.add_subscription(subscription)
if not skip_direct_message_subscription:
# Additionally adds a special prefix subscription for this agent to receive direct messages
try:
await runtime.add_subscription(
TypePrefixSubscription(
# The prefix MUST include ":" to avoid collisions with other agents
topic_type_prefix=agent_id.type + ":",
agent_type=agent_id.type,
)
)
except ValueError:
# We don't care if the subscription already exists
pass
# TODO: deduplication
for _message_type, serializer in self._handles_types():
runtime.add_message_serializer(serializer)
return agent_id
@classmethod
async def register(
cls,

View File

@ -266,6 +266,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
self._serialization_registry = SerializationRegistry()
self._ignore_unhandled_handler_exceptions = ignore_unhandled_exceptions
self._background_exception: BaseException | None = None
self._agent_instance_types: Dict[str, Type[Agent]] = {}
@property
def unprocessed_messages_count(
@ -909,6 +910,32 @@ class SingleThreadedAgentRuntime(AgentRuntime):
return type
async def register_agent_instance(
self,
agent_instance: Agent,
agent_id: AgentId,
) -> AgentId:
def agent_factory() -> Agent:
raise RuntimeError(
"Agent factory was invoked for an agent instance that was not registered. This is likely due to the agent type being incorrectly subscribed to a topic. If this exception occurs when publishing a message to the DefaultTopicId, then it is likely that `skip_class_subscriptions` needs to be turned off when registering the agent."
)
if agent_id in self._instantiated_agents:
raise ValueError(f"Agent with id {agent_id} already exists.")
if agent_id.type not in self._agent_factories:
self._agent_factories[agent_id.type] = agent_factory
self._agent_instance_types[agent_id.type] = type_func_alias(agent_instance)
else:
if self._agent_factories[agent_id.type].__code__ != agent_factory.__code__:
raise ValueError("Agent factories and agent instances cannot be registered to the same type.")
if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance):
raise ValueError("Agent instances must be the same object type.")
await agent_instance.bind_id_and_runtime(id=agent_id, runtime=self)
self._instantiated_agents[agent_id] = agent_instance
return agent_id
async def _invoke_agent_factory(
self,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
@ -930,8 +957,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
raise ValueError("Agent factory must take 0 or 2 arguments.")
if inspect.isawaitable(agent):
return cast(T, await agent)
agent = cast(T, await agent)
return agent
except BaseException as e:

View File

@ -9,7 +9,7 @@ async def test_base_agent_create(mocker: MockerFixture) -> None:
runtime = mocker.Mock(spec=AgentRuntime)
# Shows how to set the context for the agent instantiation in a test context
with AgentInstantiationContext.populate_context((runtime, AgentId("name", "namespace"))):
agent = NoopAgent()
assert agent.runtime == runtime
assert agent.id == AgentId("name", "namespace")
with AgentInstantiationContext.populate_context((runtime, AgentId("name2", "namespace2"))):
agent2 = NoopAgent()
assert agent2.runtime == runtime
assert agent2.id == AgentId("name2", "namespace2")

View File

@ -82,6 +82,60 @@ async def test_agent_type_must_be_unique() -> None:
await runtime.register_factory(type=AgentType("name2"), agent_factory=agent_factory, expected_class=NoopAgent)
@pytest.mark.asyncio
async def test_agent_type_register_instance() -> None:
runtime = SingleThreadedAgentRuntime()
agent1_id = AgentId(type="name", key="default")
agent2_id = AgentId(type="name", key="notdefault")
agent1 = NoopAgent()
agent1_dup = NoopAgent()
agent2 = NoopAgent()
await agent1.register_instance(runtime=runtime, agent_id=agent1_id)
await agent2.register_instance(runtime=runtime, agent_id=agent2_id)
assert await runtime.try_get_underlying_agent_instance(agent1_id, type=NoopAgent) == agent1
assert await runtime.try_get_underlying_agent_instance(agent2_id, type=NoopAgent) == agent2
with pytest.raises(ValueError):
await agent1_dup.register_instance(runtime=runtime, agent_id=agent1_id)
@pytest.mark.asyncio
async def test_agent_type_register_instance_different_types() -> None:
runtime = SingleThreadedAgentRuntime()
agent_id1 = AgentId(type="name", key="noop")
agent_id2 = AgentId(type="name", key="loopback")
agent1 = NoopAgent()
agent2 = LoopbackAgent()
await agent1.register_instance(runtime=runtime, agent_id=agent_id1)
with pytest.raises(ValueError):
await agent2.register_instance(runtime=runtime, agent_id=agent_id2)
@pytest.mark.asyncio
async def test_agent_type_register_instance_publish_new_source() -> None:
runtime = SingleThreadedAgentRuntime(ignore_unhandled_exceptions=False)
agent_id = AgentId(type="name", key="default")
agent1 = LoopbackAgent()
await agent1.register_instance(runtime=runtime, agent_id=agent_id)
await runtime.add_subscription(TypeSubscription("notdefault", "name"))
runtime.start()
with pytest.raises(RuntimeError):
await runtime.publish_message(MessageType(), TopicId("notdefault", "notdefault"))
await runtime.stop_when_idle()
await runtime.close()
@pytest.mark.asyncio
async def test_register_instance_factory() -> None:
runtime = SingleThreadedAgentRuntime()
agent1_id = AgentId(type="name", key="default")
agent1 = NoopAgent()
await agent1.register_instance(runtime=runtime, agent_id=agent1_id)
with pytest.raises(ValueError):
await NoopAgent.register(runtime, "name", lambda: NoopAgent())
@pytest.mark.asyncio
async def test_register_receives_publish(tracer_provider: TracerProvider) -> None:
runtime = SingleThreadedAgentRuntime(tracer_provider=tracer_provider)

View File

@ -251,6 +251,7 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
self._subscription_manager = SubscriptionManager()
self._serialization_registry = SerializationRegistry()
self._extra_grpc_config = extra_grpc_config or []
self._agent_instance_types: Dict[str, Type[Agent]] = {}
if payload_serialization_format not in {JSON_DATA_CONTENT_TYPE, PROTOBUF_DATA_CONTENT_TYPE}:
raise ValueError(f"Unsupported payload serialization format: {payload_serialization_format}")
@ -701,6 +702,14 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
except BaseException as e:
logger.error("Error handling event", exc_info=e)
async def _register_agent_type(self, agent_type: str) -> None:
if self._host_connection is None:
raise RuntimeError("Host connection is not set.")
message = agent_worker_pb2.RegisterAgentTypeRequest(type=agent_type)
_response: agent_worker_pb2.RegisterAgentTypeResponse = await self._host_connection.stub.RegisterAgent(
message, metadata=self._host_connection.metadata
)
async def register_factory(
self,
type: str | AgentType,
@ -729,14 +738,38 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
return agent_instance
self._agent_factories[type.type] = factory_wrapper
# Send the registration request message to the host.
message = agent_worker_pb2.RegisterAgentTypeRequest(type=type.type)
_response: agent_worker_pb2.RegisterAgentTypeResponse = await self._host_connection.stub.RegisterAgent(
message, metadata=self._host_connection.metadata
)
await self._register_agent_type(type.type)
return type
async def register_agent_instance(
self,
agent_instance: Agent,
agent_id: AgentId,
) -> AgentId:
def agent_factory() -> Agent:
raise RuntimeError(
"Agent factory was invoked for an agent instance that was not registered. This is likely due to the agent type being incorrectly subscribed to a topic. If this exception occurs when publishing a message to the DefaultTopicId, then it is likely that `skip_class_subscriptions` needs to be turned off when registering the agent."
)
if agent_id in self._instantiated_agents:
raise ValueError(f"Agent with id {agent_id} already exists.")
if agent_id.type not in self._agent_factories:
self._agent_factories[agent_id.type] = agent_factory
await self._register_agent_type(agent_id.type)
self._agent_instance_types[agent_id.type] = type_func_alias(agent_instance)
else:
if self._agent_factories[agent_id.type].__code__ != agent_factory.__code__:
raise ValueError("Agent factories and agent instances cannot be registered to the same type.")
if self._agent_instance_types[agent_id.type] != type_func_alias(agent_instance):
raise ValueError("Agent instances must be the same object type.")
await agent_instance.bind_id_and_runtime(id=agent_id, runtime=self)
self._instantiated_agents[agent_id] = agent_instance
return agent_id
async def _invoke_agent_factory(
self,
agent_factory: Callable[[], T | Awaitable[T]] | Callable[[AgentRuntime, AgentId], T | Awaitable[T]],
@ -757,7 +790,7 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
raise ValueError("Agent factory must take 0 or 2 arguments.")
if inspect.isawaitable(agent):
return cast(T, await agent)
agent = cast(T, await agent)
return agent

View File

@ -3,7 +3,7 @@ import logging
import os
from datetime import datetime
from typing import Any, AsyncGenerator, List, Type, Union
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock
import pytest
from autogen_core import CancellationToken, FunctionCall, Image
@ -570,7 +570,7 @@ def thought_with_tool_call_stream_client(monkeypatch: pytest.MonkeyPatch) -> Azu
)
mock_client = MagicMock()
mock_client.close = MagicMock()
mock_client.close = AsyncMock()
async def mock_complete(*args: Any, **kwargs: Any) -> Any:
if kwargs.get("stream", False):

View File

@ -577,6 +577,139 @@ async def test_grpc_max_message_size() -> None:
await host.stop()
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_agent_type_register_instance() -> None:
host_address = "localhost:50051"
agent1_id = AgentId(type="name", key="default")
agentdup_id = AgentId(type="name", key="default")
agent2_id = AgentId(type="name", key="notdefault")
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
worker = GrpcWorkerAgentRuntime(host_address=host_address)
agent1 = NoopAgent()
agent2 = NoopAgent()
agentdup = NoopAgent()
await worker.start()
await worker.register_agent_instance(agent1, agent_id=agent1_id)
await worker.register_agent_instance(agent2, agent_id=agent2_id)
with pytest.raises(ValueError):
await worker.register_agent_instance(agentdup, agent_id=agentdup_id)
assert await worker.try_get_underlying_agent_instance(agent1_id, type=NoopAgent) == agent1
assert await worker.try_get_underlying_agent_instance(agent2_id, type=NoopAgent) == agent2
await worker.stop()
await host.stop()
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_agent_type_register_instance_different_types() -> None:
host_address = "localhost:50051"
agent1_id = AgentId(type="name", key="noop")
agent2_id = AgentId(type="name", key="loopback")
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
worker = GrpcWorkerAgentRuntime(host_address=host_address)
agent1 = NoopAgent()
agent2 = LoopbackAgent()
await worker.start()
await worker.register_agent_instance(agent1, agent_id=agent1_id)
with pytest.raises(ValueError):
await worker.register_agent_instance(agent2, agent_id=agent2_id)
await worker.stop()
await host.stop()
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_register_instance_factory() -> None:
host_address = "localhost:50051"
agent1_id = AgentId(type="name", key="default")
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
worker = GrpcWorkerAgentRuntime(host_address=host_address)
agent1 = NoopAgent()
await worker.start()
await agent1.register_instance(runtime=worker, agent_id=agent1_id)
with pytest.raises(ValueError):
await NoopAgent.register(runtime=worker, type="name", factory=lambda: NoopAgent())
await worker.stop()
await host.stop()
@pytest.mark.grpc
@pytest.mark.asyncio
async def test_instance_factory_messaging() -> None:
host_address = "localhost:50051"
loopback_agent_id = AgentId(type="dm_agent", key="dm_agent")
cascading_agent_id = AgentId(type="instance_agent", key="instance_agent")
host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()
worker = GrpcWorkerAgentRuntime(host_address=host_address)
cascading_agent = CascadingAgent(max_rounds=5)
loopback_agent = LoopbackAgent()
await worker.start()
await loopback_agent.register_instance(worker, agent_id=loopback_agent_id)
resp = await worker.send_message(message=ContentMessage(content="Hello!"), recipient=loopback_agent_id)
assert resp == ContentMessage(content="Hello!")
await cascading_agent.register_instance(worker, agent_id=cascading_agent_id)
await CascadingAgent.register(worker, "factory_agent", lambda: CascadingAgent(max_rounds=5))
# instance_agent will publish a message that factory_agent will pick up
for i in range(5):
await worker.publish_message(
CascadingMessageType(round=i + 1), TopicId(type="instance_agent", source="instance_agent")
)
await asyncio.sleep(2)
agent = await worker.try_get_underlying_agent_instance(AgentId("factory_agent", "default"), CascadingAgent)
assert agent.num_calls == 4
assert cascading_agent.num_calls == 5
await worker.stop()
await host.stop()
# GrpcWorkerAgentRuntimeHost eats exceptions in the main loop
# @pytest.mark.grpc
# @pytest.mark.asyncio
# async def test_agent_type_register_instance_publish_new_source() -> None:
# host_address = "localhost:50056"
# agent_id = AgentId(type="name", key="default")
# agent1 = LoopbackAgent()
# host = GrpcWorkerAgentRuntimeHost(address=host_address)
# host.start()
# worker = GrpcWorkerAgentRuntime(host_address=host_address)
# await worker.start()
# publisher = GrpcWorkerAgentRuntime(host_address=host_address)
# publisher.add_message_serializer(try_get_known_serializers_for_type(MessageType))
# await publisher.start()
# await agent1.register_instance(worker, agent_id=agent_id)
# await worker.add_subscription(TypeSubscription("notdefault", "name"))
# with pytest.raises(RuntimeError):
# await worker.publish_message(MessageType(), TopicId("notdefault", "notdefault"))
# await asyncio.sleep(2)
# await worker.stop()
# await host.stop()
if __name__ == "__main__":
os.environ["GRPC_VERBOSITY"] = "DEBUG"
os.environ["GRPC_TRACE"] = "all"

View File

@ -176,7 +176,7 @@ async def test_invalid_request(test_config: ComponentModel, test_server: None) -
config.config["host"] = "fake"
tool = HttpTool.load_component(config)
with pytest.raises(httpx.ConnectError):
with pytest.raises((httpx.ConnectError, httpx.ConnectTimeout)):
await tool.run_json({"query": "test query", "value": 42}, CancellationToken())