mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-30 00:30:23 +00:00
Simplify handler decorator (#50)
* Simplify handler decorator * add more tests * mypy * formatting * fix 3.10 and improve type handling of decorator * test fix * format
This commit is contained in:
parent
ad513d5017
commit
8cb530f65e
@ -16,7 +16,7 @@ class Inner(TypeRoutedAgent):
|
||||
def __init__(self, name: str, router: AgentRuntime) -> None:
|
||||
super().__init__(name, router)
|
||||
|
||||
@message_handler(MessageType)
|
||||
@message_handler()
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||
return MessageType(body=f"Inner: {message.body}", sender=self.name)
|
||||
|
||||
@ -26,7 +26,7 @@ class Outer(TypeRoutedAgent):
|
||||
super().__init__(name, router)
|
||||
self._inner = inner
|
||||
|
||||
@message_handler(MessageType)
|
||||
@message_handler()
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||
inner_response = self._send_message(message, self._inner)
|
||||
inner_message = await inner_response
|
||||
|
||||
@ -38,17 +38,17 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
|
||||
self._chat_messages: List[Message] = []
|
||||
self._function_executor = function_executor
|
||||
|
||||
@message_handler(TextMessage)
|
||||
@message_handler()
|
||||
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
|
||||
# Add a user message.
|
||||
self._chat_messages.append(message)
|
||||
|
||||
@message_handler(Reset)
|
||||
@message_handler()
|
||||
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
|
||||
# Reset the chat messages.
|
||||
self._chat_messages = []
|
||||
|
||||
@message_handler(RespondNow)
|
||||
@message_handler()
|
||||
async def on_respond_now(
|
||||
self, message: RespondNow, cancellation_token: CancellationToken
|
||||
) -> TextMessage | FunctionCallMessage:
|
||||
@ -101,7 +101,7 @@ class ChatCompletionAgent(BaseChatAgent, TypeRoutedAgent):
|
||||
# Return the response.
|
||||
return final_response
|
||||
|
||||
@message_handler(FunctionCallMessage)
|
||||
@message_handler()
|
||||
async def on_tool_call_message(
|
||||
self, message: FunctionCallMessage, cancellation_token: CancellationToken
|
||||
) -> FunctionExecutionResultMessage:
|
||||
|
||||
@ -24,7 +24,7 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
|
||||
self._assistant_id = assistant_id
|
||||
self._thread_id = thread_id
|
||||
|
||||
@message_handler(TextMessage)
|
||||
@message_handler()
|
||||
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
|
||||
# Save the message to the thread.
|
||||
_ = await self._client.beta.threads.messages.create(
|
||||
@ -34,7 +34,7 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
|
||||
metadata={"sender": message.source},
|
||||
)
|
||||
|
||||
@message_handler(Reset)
|
||||
@message_handler()
|
||||
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
|
||||
# Get all messages in this thread.
|
||||
all_msgs: List[str] = []
|
||||
@ -52,7 +52,7 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
|
||||
status = await self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)
|
||||
assert status.deleted is True
|
||||
|
||||
@message_handler(RespondNow)
|
||||
@message_handler()
|
||||
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
|
||||
# Handle response format.
|
||||
if message.response_format == ResponseFormat.json_object:
|
||||
|
||||
@ -35,16 +35,16 @@ class GroupChat(BaseChatAgent, TypeRoutedAgent):
|
||||
agent_sublists = [agent.subscriptions for agent in self._agents]
|
||||
return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist]
|
||||
|
||||
@message_handler(Reset)
|
||||
@message_handler()
|
||||
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
|
||||
self._history.clear()
|
||||
|
||||
@message_handler(RespondNow)
|
||||
@message_handler()
|
||||
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> Any:
|
||||
return self._output.get_output()
|
||||
|
||||
@message_handler(TextMessage)
|
||||
async def on_text_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
|
||||
@message_handler()
|
||||
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> Any:
|
||||
# TODO: how should we handle the group chat receiving a message while in the middle of a conversation?
|
||||
# Should this class disallow it?
|
||||
|
||||
|
||||
@ -34,7 +34,7 @@ class OrchestratorChat(BaseChatAgent, TypeRoutedAgent):
|
||||
def children(self) -> Sequence[str]:
|
||||
return [agent.name for agent in self._specialists] + [self._orchestrator.name] + [self._planner.name]
|
||||
|
||||
@message_handler(TextMessage)
|
||||
@message_handler()
|
||||
async def on_text_message(
|
||||
self,
|
||||
message: TextMessage,
|
||||
|
||||
18
src/agnext/chat/patterns/two_agent_chat.py
Normal file
18
src/agnext/chat/patterns/two_agent_chat.py
Normal file
@ -0,0 +1,18 @@
|
||||
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
|
||||
|
||||
from ...core import AgentRuntime
|
||||
from ..agents.base import BaseChatAgent
|
||||
|
||||
|
||||
class TwoAgentChat(GroupChat):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
runtime: AgentRuntime,
|
||||
agent1: BaseChatAgent,
|
||||
agent2: BaseChatAgent,
|
||||
num_rounds: int,
|
||||
output: GroupChatOutput,
|
||||
) -> None:
|
||||
super().__init__(name, description, runtime, [agent1, agent2], num_rounds, output)
|
||||
@ -1,28 +1,132 @@
|
||||
from typing import Any, Callable, Coroutine, Dict, NoReturn, Sequence, Type, TypeVar
|
||||
import logging
|
||||
from functools import wraps
|
||||
from types import NoneType, UnionType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
Literal,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Protocol,
|
||||
Sequence,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
get_type_hints,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from agnext.core import AgentRuntime, BaseAgent, CancellationToken
|
||||
from agnext.core.exceptions import CantHandleException
|
||||
|
||||
ReceivesT = TypeVar("ReceivesT")
|
||||
logger = logging.getLogger("agnext")
|
||||
|
||||
ReceivesT = TypeVar("ReceivesT", contravariant=True)
|
||||
ProducesT = TypeVar("ProducesT", covariant=True)
|
||||
|
||||
# TODO: Generic typevar bound binding U to agent type
|
||||
# Can't do because python doesnt support it
|
||||
|
||||
|
||||
def is_union(t: object) -> bool:
|
||||
origin = get_origin(t)
|
||||
return origin is Union or origin is UnionType
|
||||
|
||||
|
||||
def is_optional(t: object) -> bool:
|
||||
origin = get_origin(t)
|
||||
return origin is Optional
|
||||
|
||||
|
||||
# Special type to avoid the 3.10 vs 3.11+ difference of typing._SpecialForm vs typing.Any
|
||||
class AnyType:
|
||||
pass
|
||||
|
||||
|
||||
def get_types(t: object) -> Sequence[Type[Any]] | None:
|
||||
if is_union(t):
|
||||
return get_args(t)
|
||||
elif is_optional(t):
|
||||
return tuple(list(get_args(t)) + [NoneType])
|
||||
elif t is Any:
|
||||
return (AnyType,)
|
||||
elif isinstance(t, type):
|
||||
return (t,)
|
||||
elif isinstance(t, NoneType):
|
||||
return (NoneType,)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class MessageHandler(Protocol[ReceivesT, ProducesT]):
|
||||
target_types: Sequence[type]
|
||||
produces_types: Sequence[type]
|
||||
is_message_handler: Literal[True]
|
||||
|
||||
async def __call__(self, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT: ...
|
||||
|
||||
|
||||
# NOTE: this works on concrete types and not inheritance
|
||||
# TODO: Use a protocl for the outer function to check checked arg names
|
||||
def message_handler(
|
||||
*target_types: Type[ReceivesT],
|
||||
strict: bool = True,
|
||||
) -> Callable[
|
||||
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]],
|
||||
Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
|
||||
[Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]]],
|
||||
MessageHandler[ReceivesT, ProducesT],
|
||||
]:
|
||||
def decorator(
|
||||
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
|
||||
) -> Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT | None]]:
|
||||
func: Callable[[Any, ReceivesT, CancellationToken], Coroutine[Any, Any, ProducesT]],
|
||||
) -> MessageHandler[ReceivesT, ProducesT]:
|
||||
type_hints = get_type_hints(func)
|
||||
if "message" not in type_hints:
|
||||
raise AssertionError("message parameter not found in function signature")
|
||||
|
||||
if "return" not in type_hints:
|
||||
raise AssertionError("return not found in function signature")
|
||||
|
||||
# Get the type of the message parameter
|
||||
target_types = get_types(type_hints["message"])
|
||||
if target_types is None:
|
||||
raise AssertionError("Message type not found")
|
||||
|
||||
print(type_hints)
|
||||
return_types = get_types(type_hints["return"])
|
||||
|
||||
if return_types is None:
|
||||
raise AssertionError("Return type not found")
|
||||
|
||||
# Convert target_types to list and stash
|
||||
func._target_types = list(target_types) # type: ignore
|
||||
return func
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self: Any, message: ReceivesT, cancellation_token: CancellationToken) -> ProducesT:
|
||||
if strict:
|
||||
if type(message) not in target_types:
|
||||
raise CantHandleException(f"Message type {type(message)} not in target types {target_types}")
|
||||
else:
|
||||
logger.warning(f"Message type {type(message)} not in target types {target_types}")
|
||||
|
||||
return_value = await func(self, message, cancellation_token)
|
||||
|
||||
if strict:
|
||||
if return_value is not AnyType and type(return_value) not in return_types:
|
||||
raise ValueError(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
elif return_value is not AnyType:
|
||||
logger.warning(f"Return type {type(return_value)} not in return types {return_types}")
|
||||
|
||||
return return_value
|
||||
|
||||
wrapper_handler = cast(MessageHandler[ReceivesT, ProducesT], wrapper)
|
||||
wrapper_handler.target_types = list(target_types)
|
||||
wrapper_handler.produces_types = list(return_types)
|
||||
wrapper_handler.is_message_handler = True
|
||||
|
||||
return wrapper_handler
|
||||
|
||||
return decorator
|
||||
|
||||
@ -35,9 +139,10 @@ class TypeRoutedAgent(BaseAgent):
|
||||
for attr in dir(self):
|
||||
if callable(getattr(self, attr, None)):
|
||||
handler = getattr(self, attr)
|
||||
if hasattr(handler, "_target_types"):
|
||||
for target_type in handler._target_types:
|
||||
self._handlers[target_type] = handler
|
||||
if hasattr(handler, "is_message_handler"):
|
||||
message_handler = cast(MessageHandler[Any, Any], handler)
|
||||
for target_type in message_handler.target_types:
|
||||
self._handlers[target_type] = message_handler
|
||||
|
||||
super().__init__(name, router)
|
||||
|
||||
|
||||
@ -22,7 +22,7 @@ class LongRunningAgent(TypeRoutedAgent):
|
||||
self.called = False
|
||||
self.cancelled = False
|
||||
|
||||
@message_handler(MessageType)
|
||||
@message_handler()
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||
self.called = True
|
||||
sleep = asyncio.ensure_future(asyncio.sleep(100))
|
||||
@ -41,7 +41,7 @@ class NestingLongRunningAgent(TypeRoutedAgent):
|
||||
self.cancelled = False
|
||||
self._nested_agent = nested_agent
|
||||
|
||||
@message_handler(MessageType)
|
||||
@message_handler()
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||
self.called = True
|
||||
response = self._send_message(message, self._nested_agent, cancellation_token=cancellation_token)
|
||||
|
||||
@ -19,7 +19,7 @@ class LoopbackAgent(TypeRoutedAgent):
|
||||
self.num_calls = 0
|
||||
|
||||
|
||||
@message_handler(MessageType)
|
||||
@message_handler()
|
||||
async def on_new_message(self, message: MessageType, cancellation_token: CancellationToken) -> MessageType:
|
||||
self.num_calls += 1
|
||||
return message
|
||||
|
||||
39
tests/test_types.py
Normal file
39
tests/test_types.py
Normal file
@ -0,0 +1,39 @@
|
||||
from types import NoneType
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from agnext.components.type_routed_agent import AnyType, get_types, message_handler
|
||||
from agnext.core import CancellationToken
|
||||
|
||||
|
||||
def test_get_types() -> None:
|
||||
assert get_types(Union[int, str]) == (int, str)
|
||||
assert get_types(int | str) == (int, str)
|
||||
assert get_types(int) == (int,)
|
||||
assert get_types(str) == (str,)
|
||||
assert get_types("test") is None
|
||||
assert get_types(Optional[int]) == (int, NoneType)
|
||||
assert get_types(NoneType) == (NoneType,)
|
||||
assert get_types(None) == (NoneType,)
|
||||
|
||||
|
||||
def test_handler() -> None:
|
||||
|
||||
class HandlerClass:
|
||||
@message_handler()
|
||||
async def handler(self, message: int, cancellation_token: CancellationToken) -> Any:
|
||||
return None
|
||||
|
||||
@message_handler()
|
||||
async def handler2(self, message: str | bool, cancellation_token: CancellationToken) -> None:
|
||||
return None
|
||||
|
||||
assert HandlerClass.handler.target_types == [int]
|
||||
assert HandlerClass.handler.produces_types == [AnyType]
|
||||
|
||||
assert HandlerClass.handler2.target_types == [str, bool]
|
||||
assert HandlerClass.handler2.produces_types == [NoneType]
|
||||
|
||||
class HandlerClass:
|
||||
@message_handler()
|
||||
async def handler(self, message: int, cancellation_token: CancellationToken) -> Any:
|
||||
return None
|
||||
Loading…
x
Reference in New Issue
Block a user