mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-22 15:41:56 +00:00
Allow closure agent to ignore unknown messages, add docs (#4836)
Allow closure agent to ignore unknown messages
This commit is contained in:
parent
2819515220
commit
a5681d73c6
@ -1,7 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any, Awaitable, Callable, List, Mapping, Protocol, Sequence, TypeVar, get_type_hints
|
import warnings
|
||||||
|
from typing import Any, Awaitable, Callable, List, Literal, Mapping, Protocol, Sequence, TypeVar, get_type_hints
|
||||||
|
|
||||||
from ._agent_id import AgentId
|
from ._agent_id import AgentId
|
||||||
from ._agent_instantiation import AgentInstantiationContext
|
from ._agent_instantiation import AgentInstantiationContext
|
||||||
@ -73,7 +74,11 @@ class ClosureContext(Protocol):
|
|||||||
|
|
||||||
class ClosureAgent(BaseAgent, ClosureContext):
|
class ClosureAgent(BaseAgent, ClosureContext):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, description: str, closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]]
|
self,
|
||||||
|
description: str,
|
||||||
|
closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]],
|
||||||
|
*,
|
||||||
|
unknown_type_policy: Literal["error", "warn", "ignore"] = "warn",
|
||||||
) -> None:
|
) -> None:
|
||||||
try:
|
try:
|
||||||
runtime = AgentInstantiationContext.current_runtime()
|
runtime = AgentInstantiationContext.current_runtime()
|
||||||
@ -89,6 +94,7 @@ class ClosureAgent(BaseAgent, ClosureContext):
|
|||||||
handled_types = get_handled_types_from_closure(closure)
|
handled_types = get_handled_types_from_closure(closure)
|
||||||
self._expected_types = handled_types
|
self._expected_types = handled_types
|
||||||
self._closure = closure
|
self._closure = closure
|
||||||
|
self._unknown_type_policy = unknown_type_policy
|
||||||
super().__init__(description)
|
super().__init__(description)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -110,9 +116,17 @@ class ClosureAgent(BaseAgent, ClosureContext):
|
|||||||
|
|
||||||
async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any:
|
async def on_message_impl(self, message: Any, ctx: MessageContext) -> Any:
|
||||||
if type(message) not in self._expected_types:
|
if type(message) not in self._expected_types:
|
||||||
raise CantHandleException(
|
if self._unknown_type_policy == "warn":
|
||||||
f"Message type {type(message)} not in target types {self._expected_types} of {self.id}"
|
warnings.warn(
|
||||||
)
|
f"Message type {type(message)} not in target types {self._expected_types} of {self.id}. Set unknown_type_policy to 'error' to raise an exception, or 'ignore' to suppress this warning.",
|
||||||
|
stacklevel=1,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
elif self._unknown_type_policy == "error":
|
||||||
|
raise CantHandleException(
|
||||||
|
f"Message type {type(message)} not in target types {self._expected_types} of {self.id}. Set unknown_type_policy to 'warn' to suppress this exception, or 'ignore' to suppress this warning."
|
||||||
|
)
|
||||||
|
|
||||||
return await self._closure(self, message, ctx)
|
return await self._closure(self, message, ctx)
|
||||||
|
|
||||||
async def save_state(self) -> Mapping[str, Any]:
|
async def save_state(self) -> Mapping[str, Any]:
|
||||||
@ -130,19 +144,77 @@ class ClosureAgent(BaseAgent, ClosureContext):
|
|||||||
type: str,
|
type: str,
|
||||||
closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]],
|
closure: Callable[[ClosureContext, T, MessageContext], Awaitable[Any]],
|
||||||
*,
|
*,
|
||||||
skip_class_subscriptions: bool = False,
|
unknown_type_policy: Literal["error", "warn", "ignore"] = "warn",
|
||||||
skip_direct_message_subscription: bool = False,
|
skip_direct_message_subscription: bool = False,
|
||||||
description: str = "",
|
description: str = "",
|
||||||
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None,
|
subscriptions: Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None = None,
|
||||||
) -> AgentType:
|
) -> AgentType:
|
||||||
def factory() -> ClosureAgent:
|
"""The closure agent allows you to define an agent using a closure, or function without needing to define a class. It allows values to be extracted out of the runtime.
|
||||||
return ClosureAgent(description=description, closure=closure)
|
|
||||||
|
|
||||||
|
The closure can define the type of message which is expected, or `Any` can be used to accept any type of message.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from autogen_core import SingleThreadedAgentRuntime, MessageContext, ClosureAgent, ClosureContext
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from autogen_core._default_subscription import DefaultSubscription
|
||||||
|
from autogen_core._default_topic import DefaultTopicId
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MyMessage:
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
queue = asyncio.Queue[MyMessage]()
|
||||||
|
|
||||||
|
async def output_result(_ctx: ClosureContext, message: MyMessage, ctx: MessageContext) -> None:
|
||||||
|
await queue.put(message)
|
||||||
|
|
||||||
|
runtime = SingleThreadedAgentRuntime()
|
||||||
|
await ClosureAgent.register_closure(
|
||||||
|
runtime, "output_result", output_result, subscriptions=lambda: [DefaultSubscription()]
|
||||||
|
)
|
||||||
|
|
||||||
|
runtime.start()
|
||||||
|
await runtime.publish_message(MyMessage("Hello, world!"), DefaultTopicId())
|
||||||
|
await runtime.stop_when_idle()
|
||||||
|
|
||||||
|
result = await queue.get()
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
|
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runtime (AgentRuntime): Runtime to register the agent to
|
||||||
|
type (str): Agent type of registered agent
|
||||||
|
closure (Callable[[ClosureContext, T, MessageContext], Awaitable[Any]]): Closure to handle messages
|
||||||
|
unknown_type_policy (Literal["error", "warn", "ignore"], optional): What to do if a type is encountered that does not match the closure type. Defaults to "warn".
|
||||||
|
skip_direct_message_subscription (bool, optional): Do not add direct message subscription for this agent. Defaults to False.
|
||||||
|
description (str, optional): Description of what agent does. Defaults to "".
|
||||||
|
subscriptions (Callable[[], list[Subscription] | Awaitable[list[Subscription]]] | None, optional): List of subscriptions for this closure agent. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AgentType: Type of the agent that was registered
|
||||||
|
"""
|
||||||
|
|
||||||
|
def factory() -> ClosureAgent:
|
||||||
|
return ClosureAgent(description=description, closure=closure, unknown_type_policy=unknown_type_policy)
|
||||||
|
|
||||||
|
assert len(cls._unbound_subscriptions()) == 0, "Closure agents are expected to have no class subscriptions"
|
||||||
agent_type = await cls.register(
|
agent_type = await cls.register(
|
||||||
runtime=runtime,
|
runtime=runtime,
|
||||||
type=type,
|
type=type,
|
||||||
factory=factory, # type: ignore
|
factory=factory, # type: ignore
|
||||||
skip_class_subscriptions=skip_class_subscriptions,
|
# There should be no need to process class subscriptions, as the closure agent does not have any subscriptions.s
|
||||||
|
skip_class_subscriptions=True,
|
||||||
skip_direct_message_subscription=skip_direct_message_subscription,
|
skip_direct_message_subscription=skip_direct_message_subscription,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user