doc & fix: Enhance AgentInstantiationContext with detailed documentation and examples for agent instantiation; Fix a but that caused value error when the expected class is not provided in register_factory (#5555)

Resolves #5519

Also spotted and fixed a bug that caused value error from `register_factory`, when the `expected_class` was not provided.
This commit is contained in:
Eric Zhu 2025-02-14 18:19:32 -08:00 committed by GitHub
parent 69c0b2b5ef
commit 80891b4841
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 116 additions and 5 deletions

View File

@ -7,6 +7,81 @@ from ._agent_runtime import AgentRuntime
class AgentInstantiationContext:
"""A static class that provides context for agent instantiation.
This static class can be used to access the current runtime and agent ID
during agent instantiation -- inside the factory function or the agent's
class constructor.
Example:
Get the current runtime and agent ID inside the factory function and
the agent's constructor:
.. code-block:: python
import asyncio
from dataclasses import dataclass
from autogen_core import (
AgentId,
AgentInstantiationContext,
MessageContext,
RoutedAgent,
SingleThreadedAgentRuntime,
message_handler,
)
@dataclass
class TestMessage:
content: str
class TestAgent(RoutedAgent):
def __init__(self, description: str):
super().__init__(description)
# Get the current runtime -- we don't use it here, but it's available.
_ = AgentInstantiationContext.current_runtime()
# Get the current agent ID.
agent_id = AgentInstantiationContext.current_agent_id()
print(f"Current AgentID from constructor: {agent_id}")
@message_handler
async def handle_test_message(self, message: TestMessage, ctx: MessageContext) -> None:
print(f"Received message: {message.content}")
def test_agent_factory() -> TestAgent:
# Get the current runtime -- we don't use it here, but it's available.
_ = AgentInstantiationContext.current_runtime()
# Get the current agent ID.
agent_id = AgentInstantiationContext.current_agent_id()
print(f"Current AgentID from factory: {agent_id}")
return TestAgent(description="Test agent")
async def main() -> None:
# Create a SingleThreadedAgentRuntime instance.
runtime = SingleThreadedAgentRuntime()
# Start the runtime.
runtime.start()
# Register the agent type with a factory function.
await runtime.register_factory("test_agent", test_agent_factory)
# Send a message to the agent. The runtime will instantiate the agent and call the message handler.
await runtime.send_message(TestMessage(content="Hello, world!"), AgentId("test_agent", "default"))
# Stop the runtime.
await runtime.stop()
asyncio.run(main())
"""
def __init__(self) -> None:
raise RuntimeError(
"AgentInstantiationContext cannot be instantiated. It is a static class that provides context management for agent instantiation."

View File

@ -126,7 +126,7 @@ class AgentRuntime(Protocol):
Args:
type (str): The type of agent this factory creates. It is not the same as agent class name. The `type` parameter is used to differentiate between different factory functions rather than agent classes.
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `autogen_core.AgentInstantiationContext` to access variables like the current runtime and agent ID.
expected_class (type[T] | None, optional): The expected class of the agent, used for runtime validation of the factory. Defaults to None.
expected_class (type[T] | None, optional): The expected class of the agent, used for runtime validation of the factory. Defaults to None. If None, no validation is performed.
"""
...

View File

@ -676,10 +676,20 @@ class SingleThreadedAgentRuntime(AgentRuntime):
.. code-block:: python
import asyncio
from autogen_core import SingleThreadedAgentRuntime
runtime = SingleThreadedAgentRuntime()
runtime.start()
async def main() -> None:
runtime = SingleThreadedAgentRuntime()
runtime.start()
# ... do other things ...
await runtime.stop()
asyncio.run(main())
"""
if self._run_context is not None:
@ -765,7 +775,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
else:
agent_instance = maybe_agent_instance
if type_func_alias(agent_instance) != expected_class:
if expected_class is not None and type_func_alias(agent_instance) != expected_class:
raise ValueError("Factory registered using the wrong type.")
return agent_instance

View File

@ -32,6 +32,31 @@ def tracer_provider() -> TracerProvider:
return get_test_tracer_provider(test_exporter)
@pytest.mark.asyncio
async def test_agent_type_register_factory() -> None:
runtime = SingleThreadedAgentRuntime()
def agent_factory() -> NoopAgent:
id = AgentInstantiationContext.current_agent_id()
assert id == AgentId("name1", "default")
agent = NoopAgent()
assert agent.id == id
return agent
await runtime.register_factory(type=AgentType("name1"), agent_factory=agent_factory, expected_class=NoopAgent)
with pytest.raises(ValueError):
# This should fail because the expected class does not match the actual class.
await runtime.register_factory(
type=AgentType("name1"),
agent_factory=agent_factory, # type: ignore
expected_class=CascadingAgent,
)
# Without expected_class, no error.
await runtime.register_factory(type=AgentType("name2"), agent_factory=agent_factory)
@pytest.mark.asyncio
async def test_agent_type_must_be_unique() -> None:
runtime = SingleThreadedAgentRuntime()

View File

@ -723,7 +723,7 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
else:
agent_instance = maybe_agent_instance
if type_func_alias(agent_instance) != expected_class:
if expected_class is not None and type_func_alias(agent_instance) != expected_class:
raise ValueError("Factory registered using the wrong type.")
return agent_instance

View File

@ -52,6 +52,7 @@ async def test_agent_types_must_be_unique_single_worker() -> None:
)
await worker.register_factory(type=AgentType("name4"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
await worker.register_factory(type=AgentType("name5"), agent_factory=lambda: NoopAgent())
await worker.stop()
await host.stop()