mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-27 15:09:41 +00:00
doc & fix: Enhance AgentInstantiationContext with detailed documentation and examples for agent instantiation; Fix a but that caused value error when the expected class is not provided in register_factory (#5555)
Resolves #5519 Also spotted and fixed a bug that caused value error from `register_factory`, when the `expected_class` was not provided.
This commit is contained in:
parent
69c0b2b5ef
commit
80891b4841
@ -7,6 +7,81 @@ from ._agent_runtime import AgentRuntime
|
||||
|
||||
|
||||
class AgentInstantiationContext:
|
||||
"""A static class that provides context for agent instantiation.
|
||||
|
||||
This static class can be used to access the current runtime and agent ID
|
||||
during agent instantiation -- inside the factory function or the agent's
|
||||
class constructor.
|
||||
|
||||
Example:
|
||||
|
||||
Get the current runtime and agent ID inside the factory function and
|
||||
the agent's constructor:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
from autogen_core import (
|
||||
AgentId,
|
||||
AgentInstantiationContext,
|
||||
MessageContext,
|
||||
RoutedAgent,
|
||||
SingleThreadedAgentRuntime,
|
||||
message_handler,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TestMessage:
|
||||
content: str
|
||||
|
||||
|
||||
class TestAgent(RoutedAgent):
|
||||
def __init__(self, description: str):
|
||||
super().__init__(description)
|
||||
# Get the current runtime -- we don't use it here, but it's available.
|
||||
_ = AgentInstantiationContext.current_runtime()
|
||||
# Get the current agent ID.
|
||||
agent_id = AgentInstantiationContext.current_agent_id()
|
||||
print(f"Current AgentID from constructor: {agent_id}")
|
||||
|
||||
@message_handler
|
||||
async def handle_test_message(self, message: TestMessage, ctx: MessageContext) -> None:
|
||||
print(f"Received message: {message.content}")
|
||||
|
||||
|
||||
def test_agent_factory() -> TestAgent:
|
||||
# Get the current runtime -- we don't use it here, but it's available.
|
||||
_ = AgentInstantiationContext.current_runtime()
|
||||
# Get the current agent ID.
|
||||
agent_id = AgentInstantiationContext.current_agent_id()
|
||||
print(f"Current AgentID from factory: {agent_id}")
|
||||
return TestAgent(description="Test agent")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Create a SingleThreadedAgentRuntime instance.
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
# Start the runtime.
|
||||
runtime.start()
|
||||
|
||||
# Register the agent type with a factory function.
|
||||
await runtime.register_factory("test_agent", test_agent_factory)
|
||||
|
||||
# Send a message to the agent. The runtime will instantiate the agent and call the message handler.
|
||||
await runtime.send_message(TestMessage(content="Hello, world!"), AgentId("test_agent", "default"))
|
||||
|
||||
# Stop the runtime.
|
||||
await runtime.stop()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
raise RuntimeError(
|
||||
"AgentInstantiationContext cannot be instantiated. It is a static class that provides context management for agent instantiation."
|
||||
|
||||
@ -126,7 +126,7 @@ class AgentRuntime(Protocol):
|
||||
Args:
|
||||
type (str): The type of agent this factory creates. It is not the same as agent class name. The `type` parameter is used to differentiate between different factory functions rather than agent classes.
|
||||
agent_factory (Callable[[], T]): The factory that creates the agent, where T is a concrete Agent type. Inside the factory, use `autogen_core.AgentInstantiationContext` to access variables like the current runtime and agent ID.
|
||||
expected_class (type[T] | None, optional): The expected class of the agent, used for runtime validation of the factory. Defaults to None.
|
||||
expected_class (type[T] | None, optional): The expected class of the agent, used for runtime validation of the factory. Defaults to None. If None, no validation is performed.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@ -676,10 +676,20 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from autogen_core import SingleThreadedAgentRuntime
|
||||
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
runtime.start()
|
||||
|
||||
async def main() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
runtime.start()
|
||||
|
||||
# ... do other things ...
|
||||
|
||||
await runtime.stop()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
if self._run_context is not None:
|
||||
@ -765,7 +775,7 @@ class SingleThreadedAgentRuntime(AgentRuntime):
|
||||
else:
|
||||
agent_instance = maybe_agent_instance
|
||||
|
||||
if type_func_alias(agent_instance) != expected_class:
|
||||
if expected_class is not None and type_func_alias(agent_instance) != expected_class:
|
||||
raise ValueError("Factory registered using the wrong type.")
|
||||
|
||||
return agent_instance
|
||||
|
||||
@ -32,6 +32,31 @@ def tracer_provider() -> TracerProvider:
|
||||
return get_test_tracer_provider(test_exporter)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_type_register_factory() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
def agent_factory() -> NoopAgent:
|
||||
id = AgentInstantiationContext.current_agent_id()
|
||||
assert id == AgentId("name1", "default")
|
||||
agent = NoopAgent()
|
||||
assert agent.id == id
|
||||
return agent
|
||||
|
||||
await runtime.register_factory(type=AgentType("name1"), agent_factory=agent_factory, expected_class=NoopAgent)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# This should fail because the expected class does not match the actual class.
|
||||
await runtime.register_factory(
|
||||
type=AgentType("name1"),
|
||||
agent_factory=agent_factory, # type: ignore
|
||||
expected_class=CascadingAgent,
|
||||
)
|
||||
|
||||
# Without expected_class, no error.
|
||||
await runtime.register_factory(type=AgentType("name2"), agent_factory=agent_factory)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_type_must_be_unique() -> None:
|
||||
runtime = SingleThreadedAgentRuntime()
|
||||
|
||||
@ -723,7 +723,7 @@ class GrpcWorkerAgentRuntime(AgentRuntime):
|
||||
else:
|
||||
agent_instance = maybe_agent_instance
|
||||
|
||||
if type_func_alias(agent_instance) != expected_class:
|
||||
if expected_class is not None and type_func_alias(agent_instance) != expected_class:
|
||||
raise ValueError("Factory registered using the wrong type.")
|
||||
|
||||
return agent_instance
|
||||
|
||||
@ -52,6 +52,7 @@ async def test_agent_types_must_be_unique_single_worker() -> None:
|
||||
)
|
||||
|
||||
await worker.register_factory(type=AgentType("name4"), agent_factory=lambda: NoopAgent(), expected_class=NoopAgent)
|
||||
await worker.register_factory(type=AgentType("name5"), agent_factory=lambda: NoopAgent())
|
||||
|
||||
await worker.stop()
|
||||
await host.stop()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user