diff --git a/python/packages/autogen-core/src/autogen_core/_closure_agent.py b/python/packages/autogen-core/src/autogen_core/_closure_agent.py index 03206d18f..8f93b4f2b 100644 --- a/python/packages/autogen-core/src/autogen_core/_closure_agent.py +++ b/python/packages/autogen-core/src/autogen_core/_closure_agent.py @@ -1,7 +1,8 @@ from __future__ import annotations import inspect -from typing import Any, Awaitable, Callable, List, Mapping, Protocol, Sequence, TypeVar, get_type_hints +import warnings +from typing import Any, Awaitable, Callable, List, Literal, Mapping, Protocol, Sequence, TypeVar, get_type_hints from ._agent_id import AgentId from ._agent_instantiation import AgentInstantiationContext @@ -73,7 +74,11 @@ class ClosureContext(Protocol): class ClosureAgent(BaseAgent, ClosureContext): def __init__( - self, description: str, closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]] + self, + description: str, + closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]], + *, + unknown_type_policy: Literal["error", "warn", "ignore"] = "warn", ) -> None: try: runtime = AgentInstantiationContext.current_runtime() @@ -89,6 +94,7 @@ class ClosureAgent(BaseAgent, ClosureContext): handled_types = get_handled_types_from_closure(closure) self._expected_types = handled_types self._closure = closure + self._unknown_type_policy = unknown_type_policy super().__init__(description) @property @@ -110,9 +116,17 @@ class ClosureAgent(BaseAgent, ClosureContext): async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any: if type(message) not in self._expected_types: - raise CantHandleException( - f"Message type {type(message)} not in target types {self._expected_types} of {self.id}" - ) + if self._unknown_type_policy == "warn": + warnings.warn( + f"Message type {type(message)} not in target types {self._expected_types} of {self.id}. Set unknown_type_policy to 'error' to raise an exception, or 'ignore' to suppress this warning.", + stacklevel=1, + ) + return None + elif self._unknown_type_policy == "error": + raise CantHandleException( + f"Message type {type(message)} not in target types {self._expected_types} of {self.id}. Set unknown_type_policy to 'warn' to suppress this exception, or 'ignore' to suppress this warning." + ) + return await self._closure(self, message, ctx) async def save_state(self) -> Mapping[str, Any]: @@ -130,19 +144,77 @@ class ClosureAgent(BaseAgent, ClosureContext): type: str, closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]], *, - skip_class_subscriptions: bool = False, + unknown_type_policy: Literal["error", "warn", "ignore"] = "warn", skip_direct_message_subscription: bool = False, description: str = "", subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None, ) -> AgentType: - def factory() -> ClosureAgent: - return ClosureAgent(description=description, closure=closure) + """The closure agent allows you to define an agent using a closure, or function without needing to define a class. It allows values to be extracted out of the runtime. + The closure can define the type of message which is expected, or `Any` can be used to accept any type of message. + + Example: + + .. code-block:: python + + import asyncio + from autogen_core import SingleThreadedAgentRuntime, MessageContext, ClosureAgent, ClosureContext + from dataclasses import dataclass + + from autogen_core._default_subscription import DefaultSubscription + from autogen_core._default_topic import DefaultTopicId + + + @dataclass + class MyMessage: + content: str + + + async def main(): + queue = asyncio.Queue[MyMessage]() + + async def output_result(_ctx: ClosureContext, message: MyMessage, ctx: MessageContext) -> None: + await queue.put(message) + + runtime = SingleThreadedAgentRuntime() + await ClosureAgent.register_closure( + runtime, "output_result", output_result, subscriptions=lambda: [DefaultSubscription()] + ) + + runtime.start() + await runtime.publish_message(MyMessage("Hello, world!"), DefaultTopicId()) + await runtime.stop_when_idle() + + result = await queue.get() + print(result) + + + asyncio.run(main()) + + + Args: + runtime (AgentRuntime): Runtime to register the agent to + type (str): Agent type of registered agent + closure (Callable[[ClosureContext, T, MessageContext], Awaitable[Any]]): Closure to handle messages + unknown_type_policy (Literal["error", "warn", "ignore"], optional): What to do if a type is encountered that does not match the closure type. Defaults to "warn". + skip_direct_message_subscription (bool, optional): Do not add direct message subscription for this agent. Defaults to False. + description (str, optional): Description of what agent does. Defaults to "". + subscriptions (Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None, optional): List of subscriptions for this closure agent. Defaults to None. + + Returns: + AgentType: Type of the agent that was registered + """ + + def factory() -> ClosureAgent: + return ClosureAgent(description=description, closure=closure, unknown_type_policy=unknown_type_policy) + + assert len(cls._unbound_subscriptions()) == 0, "Closure agents are expected to have no class subscriptions" agent_type = await cls.register( runtime=runtime, type=type, factory=factory, # type: ignore - skip_class_subscriptions=skip_class_subscriptions, + # There should be no need to process class subscriptions, as the closure agent does not have any subscriptions.s + skip_class_subscriptions=True, skip_direct_message_subscription=skip_direct_message_subscription, )